]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
properly type array element in any() / all()
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 22 Apr 2022 14:57:00 +0000 (10:57 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 22 Apr 2022 16:53:27 +0000 (12:53 -0400)
Fixed bug in :class:`.ARRAY` datatype in combination with :class:`.Enum` on
PostgreSQL where using the ``.any()`` method to render SQL ANY(), given
members of the Python enumeration as arguments, would produce a type
adaptation failure on all drivers.

Fixes: #6515
Change-Id: Ia1e3b4e10aaf264ed436ce6030d105fc60023433
(cherry picked from commit d023c8e1c7ad82fb249fab5155eb83dee17a160c)

doc/build/changelog/unreleased_14/6515.rst [new file with mode: 0644]
lib/sqlalchemy/sql/sqltypes.py
test/dialect/postgresql/test_compiler.py
test/dialect/postgresql/test_types.py
test/sql/test_operators.py

diff --git a/doc/build/changelog/unreleased_14/6515.rst b/doc/build/changelog/unreleased_14/6515.rst
new file mode 100644 (file)
index 0000000..0ac5332
--- /dev/null
@@ -0,0 +1,8 @@
+.. change::
+    :tags: bug, postgresql
+    :tickets: 6515
+
+    Fixed bug in :class:`.ARRAY` datatype in combination with :class:`.Enum` on
+    PostgreSQL where using the ``.any()`` method to render SQL ANY(), given
+    members of the Python enumeration as arguments, would produce a type
+    adaptation failure on all drivers.
index e51397da7f0cfd41b921c08d58f668dfd87eaf31..92aaf1c57dc07304f0908a25c2dc21452cf8565e 100644 (file)
@@ -2851,10 +2851,18 @@ class ARRAY(SchemaEventTarget, Indexable, Concatenable, TypeEngine):
             elements = util.preloaded.sql_elements
             operator = operator if operator else operators.eq
 
+            arr_type = self.type
+
             # send plain BinaryExpression so that negate remains at None,
             # leading to NOT expr for negation.
             return elements.BinaryExpression(
-                coercions.expect(roles.ExpressionElementRole, other),
+                coercions.expect(
+                    roles.BinaryElementRole,
+                    element=other,
+                    operator=operator,
+                    expr=self.expr,
+                    bindparam_type=arr_type.item_type,
+                ),
                 elements.CollectionAggregate._create_any(self.expr),
                 operator,
             )
@@ -2895,10 +2903,18 @@ class ARRAY(SchemaEventTarget, Indexable, Concatenable, TypeEngine):
             elements = util.preloaded.sql_elements
             operator = operator if operator else operators.eq
 
+            arr_type = self.type
+
             # send plain BinaryExpression so that negate remains at None,
             # leading to NOT expr for negation.
             return elements.BinaryExpression(
-                coercions.expect(roles.ExpressionElementRole, other),
+                coercions.expect(
+                    roles.BinaryElementRole,
+                    element=other,
+                    operator=operator,
+                    expr=self.expr,
+                    bindparam_type=arr_type.item_type,
+                ),
                 elements.CollectionAggregate._create_all(self.expr),
                 operator,
             )
index 49ab15261e60593671768d99c32f37167d071ec6..6bd2f2fa2be05901f83c0b96ea59cad02c754b58 100644 (file)
@@ -1523,48 +1523,48 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
         )
         self.assert_compile(
             postgresql.Any(4, c),
-            "%(param_1)s = ANY (x)",
-            checkparams={"param_1": 4},
+            "%(x_1)s = ANY (x)",
+            checkparams={"x_1": 4},
         )
 
         self.assert_compile(
             c.any(5),
-            "%(param_1)s = ANY (x)",
-            checkparams={"param_1": 5},
+            "%(x_1)s = ANY (x)",
+            checkparams={"x_1": 5},
         )
 
         self.assert_compile(
             ~c.any(5),
-            "NOT (%(param_1)s = ANY (x))",
-            checkparams={"param_1": 5},
+            "NOT (%(x_1)s = ANY (x))",
+            checkparams={"x_1": 5},
         )
 
         self.assert_compile(
             c.all(5),
-            "%(param_1)s = ALL (x)",
-            checkparams={"param_1": 5},
+            "%(x_1)s = ALL (x)",
+            checkparams={"x_1": 5},
         )
 
         self.assert_compile(
             ~c.all(5),
-            "NOT (%(param_1)s = ALL (x))",
-            checkparams={"param_1": 5},
+            "NOT (%(x_1)s = ALL (x))",
+            checkparams={"x_1": 5},
         )
 
         self.assert_compile(
             c.any(5, operator=operators.ne),
-            "%(param_1)s != ANY (x)",
-            checkparams={"param_1": 5},
+            "%(x_1)s != ANY (x)",
+            checkparams={"x_1": 5},
         )
         self.assert_compile(
             postgresql.All(6, c, operator=operators.gt),
-            "%(param_1)s > ALL (x)",
-            checkparams={"param_1": 6},
+            "%(x_1)s > ALL (x)",
+            checkparams={"x_1": 6},
         )
         self.assert_compile(
             c.all(7, operator=operators.lt),
-            "%(param_1)s < ALL (x)",
-            checkparams={"param_1": 7},
+            "%(x_1)s < ALL (x)",
+            checkparams={"x_1": 7},
         )
 
     @testing.combinations(
index fe39672627090015b01b37d6a76db615d9e7503b..ad0fcfeeea3164d6b2fd203e7a2df6f38a33bc6a 100644 (file)
@@ -1261,16 +1261,16 @@ class ArrayTest(AssertsCompiledSQL, fixtures.TestBase):
         col = column("x", postgresql.ARRAY(Integer))
         self.assert_compile(
             select(col.any(7, operator=operators.lt)),
-            "SELECT %(param_1)s < ANY (x) AS anon_1",
-            checkparams={"param_1": 7},
+            "SELECT %(x_1)s < ANY (x) AS anon_1",
+            checkparams={"x_1": 7},
         )
 
     def test_array_all(self):
         col = column("x", postgresql.ARRAY(Integer))
         self.assert_compile(
             select(col.all(7, operator=operators.lt)),
-            "SELECT %(param_1)s < ALL (x) AS anon_1",
-            checkparams={"param_1": 7},
+            "SELECT %(x_1)s < ALL (x) AS anon_1",
+            checkparams={"x_1": 7},
         )
 
     def test_array_contains(self):
@@ -2397,14 +2397,19 @@ class ArrayEnum(fixtures.TestBase):
                         array_cls(enum_cls(MyEnum)),
                     ),
                 )
+                data = [
+                    {"enum_col": ["foo"], "pyenum_col": [MyEnum.a, MyEnum.b]},
+                    {"enum_col": ["foo", "bar"], "pyenum_col": [MyEnum.b]},
+                ]
             else:
                 MyEnum = None
+                data = [
+                    {"enum_col": ["foo"]},
+                    {"enum_col": ["foo", "bar"]},
+                ]
 
             metadata.create_all(connection)
-            connection.execute(
-                tbl.insert(),
-                [{"enum_col": ["foo"]}, {"enum_col": ["foo", "bar"]}],
-            )
+            connection.execute(tbl.insert(), data)
             return tbl, MyEnum
 
         yield go
@@ -2421,6 +2426,27 @@ class ArrayEnum(fixtures.TestBase):
             )(fn)
         )
 
+    @testing.requires.python3
+    @_enum_combinations
+    @testing.combinations("all", "any", argnames="fn")
+    def test_any_all_roundtrip(
+        self, array_of_enum_fixture, connection, array_cls, enum_cls, fn
+    ):
+        """test #6515"""
+
+        tbl, MyEnum = array_of_enum_fixture(array_cls, enum_cls)
+
+        if fn == "all":
+            expr = tbl.c.pyenum_col.all(MyEnum.b)
+            result = [([MyEnum.b],)]
+        elif fn == "any":
+            expr = tbl.c.pyenum_col.any(MyEnum.b)
+            result = [([MyEnum.a, MyEnum.b],), ([MyEnum.b],)]
+        else:
+            assert False
+        sel = select(tbl.c.pyenum_col).where(expr).order_by(tbl.c.id)
+        eq_(connection.execute(sel).fetchall(), result)
+
     @_enum_combinations
     def test_array_of_enums_roundtrip(
         self, array_of_enum_fixture, connection, array_cls, enum_cls
index 4eff872f4f3be11b59e382923da5c60f6b2cbfbf..c524b0aeaa9bc5578cc0ee9651baef440ed54931 100644 (file)
@@ -3544,8 +3544,8 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL):
 
         self.assert_compile(
             t.c.arrval.any(5, operator.gt),
-            ":param_1 > ANY (tab1.arrval)",
-            checkparams={"param_1": 5},
+            ":arrval_1 > ANY (tab1.arrval)",
+            checkparams={"arrval_1": 5},
         )
 
     def test_any_array_comparator_negate_accessor(self, t_fixture):
@@ -3553,8 +3553,8 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL):
 
         self.assert_compile(
             ~t.c.arrval.any(5, operator.gt),
-            "NOT (:param_1 > ANY (tab1.arrval))",
-            checkparams={"param_1": 5},
+            "NOT (:arrval_1 > ANY (tab1.arrval))",
+            checkparams={"arrval_1": 5},
         )
 
     def test_all_array_comparator_accessor(self, t_fixture):
@@ -3562,8 +3562,8 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL):
 
         self.assert_compile(
             t.c.arrval.all(5, operator.gt),
-            ":param_1 > ALL (tab1.arrval)",
-            checkparams={"param_1": 5},
+            ":arrval_1 > ALL (tab1.arrval)",
+            checkparams={"arrval_1": 5},
         )
 
     def test_all_array_comparator_negate_accessor(self, t_fixture):
@@ -3571,8 +3571,8 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL):
 
         self.assert_compile(
             ~t.c.arrval.all(5, operator.gt),
-            "NOT (:param_1 > ALL (tab1.arrval))",
-            checkparams={"param_1": 5},
+            "NOT (:arrval_1 > ALL (tab1.arrval))",
+            checkparams={"arrval_1": 5},
         )
 
     def test_any_array_expression(self, t_fixture):