From: Mike Bayer Date: Fri, 22 Apr 2022 14:57:00 +0000 (-0400) Subject: properly type array element in any() / all() X-Git-Tag: rel_1_4_36~4^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=eb7061ea7d133eb3154a825595ef31df47f1ced2;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git properly type array element in any() / all() Fixed bug in :class:`.ARRAY` datatype in combination with :class:`.Enum` on PostgreSQL where using the ``.any()`` method to render SQL ANY(), given members of the Python enumeration as arguments, would produce a type adaptation failure on all drivers. Fixes: #6515 Change-Id: Ia1e3b4e10aaf264ed436ce6030d105fc60023433 (cherry picked from commit d023c8e1c7ad82fb249fab5155eb83dee17a160c) --- diff --git a/doc/build/changelog/unreleased_14/6515.rst b/doc/build/changelog/unreleased_14/6515.rst new file mode 100644 index 0000000000..0ac5332b55 --- /dev/null +++ b/doc/build/changelog/unreleased_14/6515.rst @@ -0,0 +1,8 @@ +.. change:: + :tags: bug, postgresql + :tickets: 6515 + + Fixed bug in :class:`.ARRAY` datatype in combination with :class:`.Enum` on + PostgreSQL where using the ``.any()`` method to render SQL ANY(), given + members of the Python enumeration as arguments, would produce a type + adaptation failure on all drivers. diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py index e51397da7f..92aaf1c57d 100644 --- a/lib/sqlalchemy/sql/sqltypes.py +++ b/lib/sqlalchemy/sql/sqltypes.py @@ -2851,10 +2851,18 @@ class ARRAY(SchemaEventTarget, Indexable, Concatenable, TypeEngine): elements = util.preloaded.sql_elements operator = operator if operator else operators.eq + arr_type = self.type + # send plain BinaryExpression so that negate remains at None, # leading to NOT expr for negation. return elements.BinaryExpression( - coercions.expect(roles.ExpressionElementRole, other), + coercions.expect( + roles.BinaryElementRole, + element=other, + operator=operator, + expr=self.expr, + bindparam_type=arr_type.item_type, + ), elements.CollectionAggregate._create_any(self.expr), operator, ) @@ -2895,10 +2903,18 @@ class ARRAY(SchemaEventTarget, Indexable, Concatenable, TypeEngine): elements = util.preloaded.sql_elements operator = operator if operator else operators.eq + arr_type = self.type + # send plain BinaryExpression so that negate remains at None, # leading to NOT expr for negation. return elements.BinaryExpression( - coercions.expect(roles.ExpressionElementRole, other), + coercions.expect( + roles.BinaryElementRole, + element=other, + operator=operator, + expr=self.expr, + bindparam_type=arr_type.item_type, + ), elements.CollectionAggregate._create_all(self.expr), operator, ) diff --git a/test/dialect/postgresql/test_compiler.py b/test/dialect/postgresql/test_compiler.py index 49ab15261e..6bd2f2fa2b 100644 --- a/test/dialect/postgresql/test_compiler.py +++ b/test/dialect/postgresql/test_compiler.py @@ -1523,48 +1523,48 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): ) self.assert_compile( postgresql.Any(4, c), - "%(param_1)s = ANY (x)", - checkparams={"param_1": 4}, + "%(x_1)s = ANY (x)", + checkparams={"x_1": 4}, ) self.assert_compile( c.any(5), - "%(param_1)s = ANY (x)", - checkparams={"param_1": 5}, + "%(x_1)s = ANY (x)", + checkparams={"x_1": 5}, ) self.assert_compile( ~c.any(5), - "NOT (%(param_1)s = ANY (x))", - checkparams={"param_1": 5}, + "NOT (%(x_1)s = ANY (x))", + checkparams={"x_1": 5}, ) self.assert_compile( c.all(5), - "%(param_1)s = ALL (x)", - checkparams={"param_1": 5}, + "%(x_1)s = ALL (x)", + checkparams={"x_1": 5}, ) self.assert_compile( ~c.all(5), - "NOT (%(param_1)s = ALL (x))", - checkparams={"param_1": 5}, + "NOT (%(x_1)s = ALL (x))", + checkparams={"x_1": 5}, ) self.assert_compile( c.any(5, operator=operators.ne), - "%(param_1)s != ANY (x)", - checkparams={"param_1": 5}, + "%(x_1)s != ANY (x)", + checkparams={"x_1": 5}, ) self.assert_compile( postgresql.All(6, c, operator=operators.gt), - "%(param_1)s > ALL (x)", - checkparams={"param_1": 6}, + "%(x_1)s > ALL (x)", + checkparams={"x_1": 6}, ) self.assert_compile( c.all(7, operator=operators.lt), - "%(param_1)s < ALL (x)", - checkparams={"param_1": 7}, + "%(x_1)s < ALL (x)", + checkparams={"x_1": 7}, ) @testing.combinations( diff --git a/test/dialect/postgresql/test_types.py b/test/dialect/postgresql/test_types.py index fe39672627..ad0fcfeeea 100644 --- a/test/dialect/postgresql/test_types.py +++ b/test/dialect/postgresql/test_types.py @@ -1261,16 +1261,16 @@ class ArrayTest(AssertsCompiledSQL, fixtures.TestBase): col = column("x", postgresql.ARRAY(Integer)) self.assert_compile( select(col.any(7, operator=operators.lt)), - "SELECT %(param_1)s < ANY (x) AS anon_1", - checkparams={"param_1": 7}, + "SELECT %(x_1)s < ANY (x) AS anon_1", + checkparams={"x_1": 7}, ) def test_array_all(self): col = column("x", postgresql.ARRAY(Integer)) self.assert_compile( select(col.all(7, operator=operators.lt)), - "SELECT %(param_1)s < ALL (x) AS anon_1", - checkparams={"param_1": 7}, + "SELECT %(x_1)s < ALL (x) AS anon_1", + checkparams={"x_1": 7}, ) def test_array_contains(self): @@ -2397,14 +2397,19 @@ class ArrayEnum(fixtures.TestBase): array_cls(enum_cls(MyEnum)), ), ) + data = [ + {"enum_col": ["foo"], "pyenum_col": [MyEnum.a, MyEnum.b]}, + {"enum_col": ["foo", "bar"], "pyenum_col": [MyEnum.b]}, + ] else: MyEnum = None + data = [ + {"enum_col": ["foo"]}, + {"enum_col": ["foo", "bar"]}, + ] metadata.create_all(connection) - connection.execute( - tbl.insert(), - [{"enum_col": ["foo"]}, {"enum_col": ["foo", "bar"]}], - ) + connection.execute(tbl.insert(), data) return tbl, MyEnum yield go @@ -2421,6 +2426,27 @@ class ArrayEnum(fixtures.TestBase): )(fn) ) + @testing.requires.python3 + @_enum_combinations + @testing.combinations("all", "any", argnames="fn") + def test_any_all_roundtrip( + self, array_of_enum_fixture, connection, array_cls, enum_cls, fn + ): + """test #6515""" + + tbl, MyEnum = array_of_enum_fixture(array_cls, enum_cls) + + if fn == "all": + expr = tbl.c.pyenum_col.all(MyEnum.b) + result = [([MyEnum.b],)] + elif fn == "any": + expr = tbl.c.pyenum_col.any(MyEnum.b) + result = [([MyEnum.a, MyEnum.b],), ([MyEnum.b],)] + else: + assert False + sel = select(tbl.c.pyenum_col).where(expr).order_by(tbl.c.id) + eq_(connection.execute(sel).fetchall(), result) + @_enum_combinations def test_array_of_enums_roundtrip( self, array_of_enum_fixture, connection, array_cls, enum_cls diff --git a/test/sql/test_operators.py b/test/sql/test_operators.py index 4eff872f4f..c524b0aeaa 100644 --- a/test/sql/test_operators.py +++ b/test/sql/test_operators.py @@ -3544,8 +3544,8 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL): self.assert_compile( t.c.arrval.any(5, operator.gt), - ":param_1 > ANY (tab1.arrval)", - checkparams={"param_1": 5}, + ":arrval_1 > ANY (tab1.arrval)", + checkparams={"arrval_1": 5}, ) def test_any_array_comparator_negate_accessor(self, t_fixture): @@ -3553,8 +3553,8 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL): self.assert_compile( ~t.c.arrval.any(5, operator.gt), - "NOT (:param_1 > ANY (tab1.arrval))", - checkparams={"param_1": 5}, + "NOT (:arrval_1 > ANY (tab1.arrval))", + checkparams={"arrval_1": 5}, ) def test_all_array_comparator_accessor(self, t_fixture): @@ -3562,8 +3562,8 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL): self.assert_compile( t.c.arrval.all(5, operator.gt), - ":param_1 > ALL (tab1.arrval)", - checkparams={"param_1": 5}, + ":arrval_1 > ALL (tab1.arrval)", + checkparams={"arrval_1": 5}, ) def test_all_array_comparator_negate_accessor(self, t_fixture): @@ -3571,8 +3571,8 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL): self.assert_compile( ~t.c.arrval.all(5, operator.gt), - "NOT (:param_1 > ALL (tab1.arrval))", - checkparams={"param_1": 5}, + "NOT (:arrval_1 > ALL (tab1.arrval))", + checkparams={"arrval_1": 5}, ) def test_any_array_expression(self, t_fixture):