From: Mike Bayer Date: Fri, 22 Apr 2022 14:57:00 +0000 (-0400) Subject: properly type array element in any() / all() X-Git-Tag: rel_2_0_0b1~336 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=63191fbef63ebfbf57e7b66bd6529305fc62c605;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git properly type array element in any() / all() Fixed bug in :class:`.ARRAY` datatype in combination with :class:`.Enum` on PostgreSQL where using the ``.any()`` method to render SQL ANY(), given members of the Python enumeration as arguments, would produce a type adaptation failure on all drivers. Fixes: #6515 Change-Id: Ia1e3b4e10aaf264ed436ce6030d105fc60023433 --- diff --git a/doc/build/changelog/unreleased_14/6515.rst b/doc/build/changelog/unreleased_14/6515.rst new file mode 100644 index 0000000000..0ac5332b55 --- /dev/null +++ b/doc/build/changelog/unreleased_14/6515.rst @@ -0,0 +1,8 @@ +.. change:: + :tags: bug, postgresql + :tickets: 6515 + + Fixed bug in :class:`.ARRAY` datatype in combination with :class:`.Enum` on + PostgreSQL where using the ``.any()`` method to render SQL ANY(), given + members of the Python enumeration as arguments, would produce a type + adaptation failure on all drivers. diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py index 64d6ea81be..65b97d5650 100644 --- a/lib/sqlalchemy/sql/sqltypes.py +++ b/lib/sqlalchemy/sql/sqltypes.py @@ -2715,9 +2715,11 @@ class ARRAY( __slots__ = () + type: ARRAY + def _setup_getitem(self, index): - arr_type = cast(ARRAY, self.type) + arr_type = self.type return_type: TypeEngine[Any] @@ -2784,10 +2786,18 @@ class ARRAY( elements = util.preloaded.sql_elements operator = operator if operator else operators.eq + arr_type = self.type + # send plain BinaryExpression so that negate remains at None, # leading to NOT expr for negation. return elements.BinaryExpression( - coercions.expect(roles.ExpressionElementRole, other), + coercions.expect( + roles.BinaryElementRole, + element=other, + operator=operator, + expr=self.expr, + bindparam_type=arr_type.item_type, + ), elements.CollectionAggregate._create_any(self.expr), operator, ) @@ -2828,10 +2838,18 @@ class ARRAY( elements = util.preloaded.sql_elements operator = operator if operator else operators.eq + arr_type = self.type + # send plain BinaryExpression so that negate remains at None, # leading to NOT expr for negation. return elements.BinaryExpression( - coercions.expect(roles.ExpressionElementRole, other), + coercions.expect( + roles.BinaryElementRole, + element=other, + operator=operator, + expr=self.expr, + bindparam_type=arr_type.item_type, + ), elements.CollectionAggregate._create_all(self.expr), operator, ) diff --git a/test/dialect/postgresql/test_compiler.py b/test/dialect/postgresql/test_compiler.py index 2221fd30a8..0fe5f70660 100644 --- a/test/dialect/postgresql/test_compiler.py +++ b/test/dialect/postgresql/test_compiler.py @@ -1506,48 +1506,48 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): ) self.assert_compile( postgresql.Any(4, c), - "%(param_1)s = ANY (x)", - checkparams={"param_1": 4}, + "%(x_1)s = ANY (x)", + checkparams={"x_1": 4}, ) self.assert_compile( c.any(5), - "%(param_1)s = ANY (x)", - checkparams={"param_1": 5}, + "%(x_1)s = ANY (x)", + checkparams={"x_1": 5}, ) self.assert_compile( ~c.any(5), - "NOT (%(param_1)s = ANY (x))", - checkparams={"param_1": 5}, + "NOT (%(x_1)s = ANY (x))", + checkparams={"x_1": 5}, ) self.assert_compile( c.all(5), - "%(param_1)s = ALL (x)", - checkparams={"param_1": 5}, + "%(x_1)s = ALL (x)", + checkparams={"x_1": 5}, ) self.assert_compile( ~c.all(5), - "NOT (%(param_1)s = ALL (x))", - checkparams={"param_1": 5}, + "NOT (%(x_1)s = ALL (x))", + checkparams={"x_1": 5}, ) self.assert_compile( c.any(5, operator=operators.ne), - "%(param_1)s != ANY (x)", - checkparams={"param_1": 5}, + "%(x_1)s != ANY (x)", + checkparams={"x_1": 5}, ) self.assert_compile( postgresql.All(6, c, operator=operators.gt), - "%(param_1)s > ALL (x)", - checkparams={"param_1": 6}, + "%(x_1)s > ALL (x)", + checkparams={"x_1": 6}, ) self.assert_compile( c.all(7, operator=operators.lt), - "%(param_1)s < ALL (x)", - checkparams={"param_1": 7}, + "%(x_1)s < ALL (x)", + checkparams={"x_1": 7}, ) @testing.combinations( diff --git a/test/dialect/postgresql/test_types.py b/test/dialect/postgresql/test_types.py index 8c4bb7fe73..bca952ade0 100644 --- a/test/dialect/postgresql/test_types.py +++ b/test/dialect/postgresql/test_types.py @@ -1271,16 +1271,16 @@ class ArrayTest(AssertsCompiledSQL, fixtures.TestBase): col = column("x", postgresql.ARRAY(Integer)) self.assert_compile( select(col.any(7, operator=operators.lt)), - "SELECT %(param_1)s < ANY (x) AS anon_1", - checkparams={"param_1": 7}, + "SELECT %(x_1)s < ANY (x) AS anon_1", + checkparams={"x_1": 7}, ) def test_array_all(self): col = column("x", postgresql.ARRAY(Integer)) self.assert_compile( select(col.all(7, operator=operators.lt)), - "SELECT %(param_1)s < ALL (x) AS anon_1", - checkparams={"param_1": 7}, + "SELECT %(x_1)s < ALL (x) AS anon_1", + checkparams={"x_1": 7}, ) def test_array_contains(self): @@ -2404,7 +2404,10 @@ class ArrayEnum(fixtures.TestBase): metadata.create_all(connection) connection.execute( tbl.insert(), - [{"enum_col": ["foo"]}, {"enum_col": ["foo", "bar"]}], + [ + {"enum_col": ["foo"], "pyenum_col": [MyEnum.a, MyEnum.b]}, + {"enum_col": ["foo", "bar"], "pyenum_col": [MyEnum.b]}, + ], ) return tbl, MyEnum @@ -2422,6 +2425,26 @@ class ArrayEnum(fixtures.TestBase): )(fn) ) + @_enum_combinations + @testing.combinations("all", "any", argnames="fn") + def test_any_all_roundtrip( + self, array_of_enum_fixture, connection, array_cls, enum_cls, fn + ): + """test #6515""" + + tbl, MyEnum = array_of_enum_fixture(array_cls, enum_cls) + + if fn == "all": + expr = tbl.c.pyenum_col.all(MyEnum.b) + result = [([MyEnum.b],)] + elif fn == "any": + expr = tbl.c.pyenum_col.any(MyEnum.b) + result = [([MyEnum.a, MyEnum.b],), ([MyEnum.b],)] + else: + assert False + sel = select(tbl.c.pyenum_col).where(expr).order_by(tbl.c.id) + eq_(connection.execute(sel).fetchall(), result) + @_enum_combinations def test_array_of_enums_roundtrip( self, array_of_enum_fixture, connection, array_cls, enum_cls diff --git a/test/sql/test_operators.py b/test/sql/test_operators.py index 88d1ea0530..77ca95de73 100644 --- a/test/sql/test_operators.py +++ b/test/sql/test_operators.py @@ -3795,8 +3795,8 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL): self.assert_compile( t.c.arrval.any(5, operator.gt), - ":param_1 > ANY (tab1.arrval)", - checkparams={"param_1": 5}, + ":arrval_1 > ANY (tab1.arrval)", + checkparams={"arrval_1": 5}, ) def test_any_array_comparator_negate_accessor(self, t_fixture): @@ -3804,8 +3804,8 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL): self.assert_compile( ~t.c.arrval.any(5, operator.gt), - "NOT (:param_1 > ANY (tab1.arrval))", - checkparams={"param_1": 5}, + "NOT (:arrval_1 > ANY (tab1.arrval))", + checkparams={"arrval_1": 5}, ) def test_all_array_comparator_accessor(self, t_fixture): @@ -3813,8 +3813,8 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL): self.assert_compile( t.c.arrval.all(5, operator.gt), - ":param_1 > ALL (tab1.arrval)", - checkparams={"param_1": 5}, + ":arrval_1 > ALL (tab1.arrval)", + checkparams={"arrval_1": 5}, ) def test_all_array_comparator_negate_accessor(self, t_fixture): @@ -3822,8 +3822,8 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL): self.assert_compile( ~t.c.arrval.all(5, operator.gt), - "NOT (:param_1 > ALL (tab1.arrval))", - checkparams={"param_1": 5}, + "NOT (:arrval_1 > ALL (tab1.arrval))", + checkparams={"arrval_1": 5}, ) def test_any_array_expression(self, t_fixture):