From: Mike Bayer Date: Wed, 3 Oct 2018 14:40:38 +0000 (-0400) Subject: Support tuples of heterogeneous types for empty expanding IN X-Git-Tag: rel_1_3_0b1~52^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=aa2128427064a2bdeaeff5dc946ecbb3727c90aa;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Support tuples of heterogeneous types for empty expanding IN Pass a list of all the types for the left side of an IN expression to the visit_empty_set_expr() method, so that the "empty expanding IN" can produce clauses for each element. Fixes: #4271 Change-Id: I2738b9df2292ac01afda37f16d4fa56ae7bf9147 --- diff --git a/doc/build/changelog/migration_13.rst b/doc/build/changelog/migration_13.rst index 5a8e3ce05b..40363d4268 100644 --- a/doc/build/changelog/migration_13.rst +++ b/doc/build/changelog/migration_13.rst @@ -413,6 +413,24 @@ backend, such as "SELECT CAST(NULL AS INTEGER) WHERE 1!=1" for Postgresql, ... SELECT 1 WHERE 1 IN (SELECT CAST(NULL AS INTEGER) WHERE 1!=1) +The feature also works for tuple-oriented IN statements, where the "empty IN" +expression will be expanded to support the elements given inside the tuple, +such as on Postgresql:: + + >>> from sqlalchemy import create_engine + >>> from sqlalchemy import select, literal_column, tuple_, bindparam + >>> e = create_engine("postgresql://scott:tiger@localhost/test", echo=True) + >>> with e.connect() as conn: + ... conn.execute( + ... select([literal_column('1')]). + ... where(tuple_(50, "somestring").in_(bindparam('q', expanding=True))), + ... q=[] + ... ) + ... + SELECT 1 WHERE (%(param_1)s, %(param_2)s) + IN (SELECT CAST(NULL AS INTEGER), CAST(NULL AS VARCHAR) WHERE 1!=1) + + :ticket:`4271` .. _change_3981: diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py index 45e8c25104..43966d1dc8 100644 --- a/lib/sqlalchemy/dialects/mysql/base.py +++ b/lib/sqlalchemy/dialects/mysql/base.py @@ -1179,8 +1179,18 @@ class MySQLCompiler(compiler.SQLCompiler): fromhints=from_hints, **kw) for t in [from_table] + extra_froms) - def visit_empty_set_expr(self, type_): - return 'SELECT 1 FROM (SELECT 1) as _empty_set WHERE 1!=1' + def visit_empty_set_expr(self, element_types): + return ( + "SELECT %(outer)s FROM (SELECT %(inner)s) " + "as _empty_set WHERE 1!=1" % { + "inner": ", ".join( + "1 AS _in_%s" % idx + for idx, type_ in enumerate(element_types)), + "outer": ", ".join( + "_in_%s" % idx + for idx, type_ in enumerate(element_types)) + } + ) class MySQLDDLCompiler(compiler.DDLCompiler): diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index 11fcc41d59..5251a000da 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -1485,14 +1485,17 @@ class PGCompiler(compiler.SQLCompiler): if escape else '' ) - def visit_empty_set_expr(self, type_, **kw): + def visit_empty_set_expr(self, element_types): # cast the empty set to the type we are comparing against. if # we are comparing against the null type, pick an arbitrary # datatype for the empty set - if type_._isnull: - type_ = INTEGER() - return 'SELECT CAST(NULL AS %s) WHERE 1!=1' % \ - self.dialect.type_compiler.process(type_, **kw) + return 'SELECT %s WHERE 1!=1' % ( + ", ".join( + "CAST(NULL AS %s)" % self.dialect.type_compiler.process( + INTEGER() if type_._isnull else type_, + ) for type_ in element_types or [INTEGER()] + ), + ) def render_literal_value(self, value, type_): value = super(PGCompiler, self).render_literal_value(value, type_) diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index f48217a4e9..5c96e4240e 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -737,7 +737,10 @@ class DefaultExecutionContext(interfaces.ExecutionContext): to_update = [] replacement_expressions[name] = ( self.compiled.visit_empty_set_expr( - type_=parameter.type) + parameter._expanding_in_types + if parameter._expanding_in_types + else [parameter.type] + ) ) elif isinstance(values[0], (tuple, list)): diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 2f68b7e2e4..27ee4afc68 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -1056,7 +1056,7 @@ class SQLCompiler(Compiled): self._emit_empty_in_warning() return self.process(binary.left == binary.left) - def visit_empty_set_expr(self, type_): + def visit_empty_set_expr(self, element_types): raise NotImplementedError( "Dialect '%s' does not support empty set expression." % self.dialect.name diff --git a/lib/sqlalchemy/sql/default_comparator.py b/lib/sqlalchemy/sql/default_comparator.py index 5d02f65a15..8149f9731d 100644 --- a/lib/sqlalchemy/sql/default_comparator.py +++ b/lib/sqlalchemy/sql/default_comparator.py @@ -15,7 +15,8 @@ from .elements import BindParameter, True_, False_, BinaryExpression, \ Null, _const_expr, _clause_element_as_expr, \ ClauseList, ColumnElement, TextClause, UnaryExpression, \ collate, _is_literal, _literal_as_text, ClauseElement, and_, or_, \ - Slice, Visitable, _literal_as_binds, CollectionAggregate + Slice, Visitable, _literal_as_binds, CollectionAggregate, \ + Tuple from .selectable import SelectBase, Alias, Selectable, ScalarSelect @@ -145,6 +146,14 @@ def _in_impl(expr, op, seq_or_selectable, negate_op, **kw): elif isinstance(seq_or_selectable, ClauseElement): if isinstance(seq_or_selectable, BindParameter) and \ seq_or_selectable.expanding: + + if isinstance(expr, Tuple): + seq_or_selectable = ( + seq_or_selectable._with_expanding_in_types( + [elem.type for elem in expr] + ) + ) + return _boolean_compare( expr, op, seq_or_selectable, diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index dd16b68628..de3b7992af 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -865,6 +865,7 @@ class BindParameter(ColumnElement): __visit_name__ = 'bindparam' _is_crud = False + _expanding_in_types = () def __init__(self, key, value=NO_ARG, type_=None, unique=False, required=NO_ARG, @@ -1134,6 +1135,15 @@ class BindParameter(ColumnElement): else: self.type = type_ + def _with_expanding_in_types(self, types): + """Return a copy of this :class:`.BindParameter` in + the context of an expanding IN against a tuple. + + """ + cloned = self._clone() + cloned._expanding_in_types = types + return cloned + def _with_value(self, value): """Return a copy of this :class:`.BindParameter` with the given value set. diff --git a/lib/sqlalchemy/testing/suite/test_select.py b/lib/sqlalchemy/testing/suite/test_select.py index 78b34f4962..73ce02492f 100644 --- a/lib/sqlalchemy/testing/suite/test_select.py +++ b/lib/sqlalchemy/testing/suite/test_select.py @@ -402,6 +402,34 @@ class ExpandingBoundInTest(fixtures.TablesTest): params={"q": [], "p": []}, ) + @testing.requires.tuple_in + def test_empty_heterogeneous_tuples(self): + table = self.tables.some_table + + stmt = select([table.c.id]).where( + tuple_(table.c.x, table.c.z).in_( + bindparam('q', expanding=True))).order_by(table.c.id) + + self._assert_result( + stmt, + [], + params={"q": []}, + ) + + @testing.requires.tuple_in + def test_empty_homogeneous_tuples(self): + table = self.tables.some_table + + stmt = select([table.c.id]).where( + tuple_(table.c.x, table.c.y).in_( + bindparam('q', expanding=True))).order_by(table.c.id) + + self._assert_result( + stmt, + [], + params={"q": []}, + ) + def test_bound_in_scalar(self): table = self.tables.some_table @@ -428,6 +456,20 @@ class ExpandingBoundInTest(fixtures.TablesTest): params={"q": [(2, 3), (3, 4), (4, 5)]}, ) + @testing.requires.tuple_in + def test_bound_in_heterogeneous_two_tuple(self): + table = self.tables.some_table + + stmt = select([table.c.id]).where( + tuple_(table.c.x, table.c.z).in_( + bindparam('q', expanding=True))).order_by(table.c.id) + + self._assert_result( + stmt, + [(2, ), (3, ), (4, )], + params={"q": [(2, "z2"), (3, "z3"), (4, "z4")]}, + ) + def test_empty_set_against_integer(self): table = self.tables.some_table