]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
check for TypeDecorator when handling getitem
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 19 Jul 2022 14:50:05 +0000 (10:50 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 19 Jul 2022 16:37:57 +0000 (12:37 -0400)
Fixed issue where :class:`.TypeDecorator` would not correctly proxy the
``__getitem__()`` operator when decorating the :class:`.ARRAY` datatype,
without explicit workarounds.

Fixes: #7249
Change-Id: I3273572b4757e41fb5952639cb867314227d368a
(cherry picked from commit 1e01fab7e600c53284eabceceab5706e4074eb2e)

doc/build/changelog/unreleased_14/7249.rst [new file with mode: 0644]
lib/sqlalchemy/sql/default_comparator.py
lib/sqlalchemy/sql/type_api.py
test/sql/test_types.py

diff --git a/doc/build/changelog/unreleased_14/7249.rst b/doc/build/changelog/unreleased_14/7249.rst
new file mode 100644 (file)
index 0000000..5d0cb65
--- /dev/null
@@ -0,0 +1,7 @@
+.. change::
+    :tags: bug, types
+    :tickets: 7249
+
+    Fixed issue where :class:`.TypeDecorator` would not correctly proxy the
+    ``__getitem__()`` operator when decorating the :class:`.ARRAY` datatype,
+    without explicit workarounds.
index 7d2f1dd2a4a3e5010ab7104e3c46ac6891e97459..70586c696f0c885f8d4293d84abe57ff81155e7d 100644 (file)
@@ -168,7 +168,11 @@ def _in_impl(expr, op, seq_or_selectable, negate_op, **kw):
 
 
 def _getitem_impl(expr, op, other, **kw):
-    if isinstance(expr.type, type_api.INDEXABLE):
+    if (
+        isinstance(expr.type, type_api.INDEXABLE)
+        or isinstance(expr.type, type_api.TypeDecorator)
+        and isinstance(expr.type.impl, type_api.INDEXABLE)
+    ):
         other = coercions.expect(
             roles.BinaryElementRole, other, expr=expr, operator=op
         )
index 7431c08a41d2f6b5ff76805dc6b7c1bd37a92df0..172ce0d884ec9d0345051c58ebb59a0d704e33dd 100644 (file)
@@ -1265,8 +1265,11 @@ class TypeDecorator(ExternalType, SchemaEventTarget, TypeEngine):
        the default rules of :meth:`.TypeEngine.coerce_compared_value` should
        be used in order to deal with operators like index operations::
 
+            from sqlalchemy import JSON
+            from sqlalchemy import TypeDecorator
+
             class MyJsonType(TypeDecorator):
-                impl = postgresql.JSON
+                impl = JSON
 
                 cache_ok = True
 
@@ -1276,6 +1279,24 @@ class TypeDecorator(ExternalType, SchemaEventTarget, TypeEngine):
        Without the above step, index operations such as ``mycol['foo']``
        will cause the index value ``'foo'`` to be JSON encoded.
 
+       Similarly, when working with the :class:`.ARRAY` datatype, the
+       type coercion for index operations (e.g. ``mycol[5]``) is also
+       handled by :meth:`.TypeDecorator.coerce_compared_value`, where
+       again a simple override is sufficient unless special rules are needed
+       for particular operators::
+
+            from sqlalchemy import ARRAY
+            from sqlalchemy import TypeDecorator
+
+            class MyArrayType(TypeDecorator):
+                impl = ARRAY
+
+                cache_ok = True
+
+                def coerce_compared_value(self, op, value):
+                    return self.impl.coerce_compared_value(op, value)
+
+
     """
 
     __visit_name__ = "type_decorator"
index 12932a1c9c789650dcca12f754e233df49eab3d7..c4f2f27260a769d8589813e0fa2150fba8e6d8eb 100644 (file)
@@ -784,6 +784,136 @@ class UserDefinedRoundTripTest(_UserDefinedTypeFixture, fixtures.TablesTest):
         eq_(result.fetchall(), [(3, 1500), (4, 900)])
 
 
+class TypeDecoratorSpecialCasesTest(AssertsCompiledSQL, fixtures.TestBase):
+    __backend__ = True
+
+    @testing.requires.array_type
+    def test_typedec_of_array_modified(self, metadata, connection):
+        """test #7249"""
+
+        class SkipsFirst(TypeDecorator):  # , Indexable):
+            impl = ARRAY(Integer, zero_indexes=True)
+
+            cache_ok = True
+
+            def process_bind_param(self, value, dialect):
+                return value[1:]
+
+            def copy(self, **kw):
+                return SkipsFirst(**kw)
+
+            def coerce_compared_value(self, op, value):
+                return self.impl.coerce_compared_value(op, value)
+
+        t = Table(
+            "t",
+            metadata,
+            Column("id", Integer, primary_key=True),
+            Column("data", SkipsFirst),
+        )
+        t.create(connection)
+
+        connection.execute(t.insert(), {"data": [1, 2, 3]})
+        val = connection.scalar(select(t.c.data))
+        eq_(val, [2, 3])
+
+        val = connection.scalar(select(t.c.data[0]))
+        eq_(val, 2)
+
+    def test_typedec_of_array_ops(self):
+        class ArrayDec(TypeDecorator):
+            impl = ARRAY(Integer, zero_indexes=True)
+
+            cache_ok = True
+
+            def coerce_compared_value(self, op, value):
+                return self.impl.coerce_compared_value(op, value)
+
+        expr1 = column("q", ArrayDec)[0]
+        expr2 = column("q", ARRAY(Integer, zero_indexes=True))[0]
+
+        eq_(expr1.right.type._type_affinity, Integer)
+        eq_(expr2.right.type._type_affinity, Integer)
+
+        self.assert_compile(
+            column("q", ArrayDec).any(7, operator=operators.lt),
+            "%(q_1)s < ANY (q)",
+            dialect="postgresql",
+        )
+
+        self.assert_compile(
+            column("q", ArrayDec)[5], "q[%(q_1)s]", dialect="postgresql"
+        )
+
+    def test_typedec_of_json_ops(self):
+        class JsonDec(TypeDecorator):
+            impl = JSON()
+
+            cache_ok = True
+
+        self.assert_compile(
+            column("q", JsonDec)["q"], "q -> %(q_1)s", dialect="postgresql"
+        )
+
+        self.assert_compile(
+            column("q", JsonDec)["q"].as_integer(),
+            "CAST(q ->> %(q_1)s AS INTEGER)",
+            dialect="postgresql",
+        )
+
+    @testing.requires.array_type
+    def test_typedec_of_array(self, metadata, connection):
+        """test #7249"""
+
+        class ArrayDec(TypeDecorator):
+            impl = ARRAY(Integer, zero_indexes=True)
+
+            cache_ok = True
+
+            def coerce_compared_value(self, op, value):
+                return self.impl.coerce_compared_value(op, value)
+
+        t = Table(
+            "t",
+            metadata,
+            Column("id", Integer, primary_key=True),
+            Column("data", ArrayDec),
+        )
+
+        t.create(connection)
+
+        connection.execute(t.insert(), {"data": [1, 2, 3]})
+        val = connection.scalar(select(t.c.data))
+        eq_(val, [1, 2, 3])
+
+        val = connection.scalar(select(t.c.data[0]))
+        eq_(val, 1)
+
+    @testing.requires.json_type
+    def test_typedec_of_json(self, metadata, connection):
+        """test #7249"""
+
+        class JsonDec(TypeDecorator):
+            impl = JSON()
+
+            cache_ok = True
+
+        t = Table(
+            "t",
+            metadata,
+            Column("id", Integer, primary_key=True),
+            Column("data", JsonDec),
+        )
+        t.create(connection)
+
+        connection.execute(t.insert(), {"data": {"key": "value"}})
+        val = connection.scalar(select(t.c.data))
+        eq_(val, {"key": "value"})
+
+        val = connection.scalar(select(t.c.data["key"].as_string()))
+        eq_(val, "value")
+
+
 class BindProcessorInsertValuesTest(UserDefinedRoundTripTest):
     """related to #6770, test that insert().values() applies to
     bound parameter handlers including the None value."""