]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
check for TypeDecorator when handling getitem
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 19 Jul 2022 14:50:05 +0000 (10:50 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 19 Jul 2022 16:37:26 +0000 (12:37 -0400)
Fixed issue where :class:`.TypeDecorator` would not correctly proxy the
``__getitem__()`` operator when decorating the :class:`.ARRAY` datatype,
without explicit workarounds.

Fixes: #7249
Change-Id: I3273572b4757e41fb5952639cb867314227d368a

doc/build/changelog/unreleased_14/7249.rst [new file with mode: 0644]
lib/sqlalchemy/sql/default_comparator.py
lib/sqlalchemy/sql/type_api.py
test/sql/test_types.py

diff --git a/doc/build/changelog/unreleased_14/7249.rst b/doc/build/changelog/unreleased_14/7249.rst
new file mode 100644 (file)
index 0000000..5d0cb65
--- /dev/null
@@ -0,0 +1,7 @@
+.. change::
+    :tags: bug, types
+    :tickets: 7249
+
+    Fixed issue where :class:`.TypeDecorator` would not correctly proxy the
+    ``__getitem__()`` operator when decorating the :class:`.ARRAY` datatype,
+    without explicit workarounds.
index 512fca8d0939c6894b1919b97fc8802ec7a2b5ce..619be2cd1ee1cd1b48e4fa3646fe246277800485 100644 (file)
@@ -232,7 +232,11 @@ def _in_impl(
 def _getitem_impl(
     expr: ColumnElement[Any], op: OperatorType, other: Any, **kw: Any
 ) -> ColumnElement[Any]:
-    if isinstance(expr.type, type_api.INDEXABLE):
+    if (
+        isinstance(expr.type, type_api.INDEXABLE)
+        or isinstance(expr.type, type_api.TypeDecorator)
+        and isinstance(expr.type.impl_instance, type_api.INDEXABLE)
+    ):
         other = coercions.expect(
             roles.BinaryElementRole, other, expr=expr, operator=op
         )
index efaf5d2a79201a04b39358be3a0263c996a44245..6c1a99daae9191f9b90970868c25c32bf336475b 100644 (file)
@@ -1551,8 +1551,11 @@ class TypeDecorator(SchemaEventTarget, ExternalType, TypeEngine[_T]):
        the default rules of :meth:`.TypeEngine.coerce_compared_value` should
        be used in order to deal with operators like index operations::
 
+            from sqlalchemy import JSON
+            from sqlalchemy import TypeDecorator
+
             class MyJsonType(TypeDecorator):
-                impl = postgresql.JSON
+                impl = JSON
 
                 cache_ok = True
 
@@ -1562,6 +1565,24 @@ class TypeDecorator(SchemaEventTarget, ExternalType, TypeEngine[_T]):
        Without the above step, index operations such as ``mycol['foo']``
        will cause the index value ``'foo'`` to be JSON encoded.
 
+       Similarly, when working with the :class:`.ARRAY` datatype, the
+       type coercion for index operations (e.g. ``mycol[5]``) is also
+       handled by :meth:`.TypeDecorator.coerce_compared_value`, where
+       again a simple override is sufficient unless special rules are needed
+       for particular operators::
+
+            from sqlalchemy import ARRAY
+            from sqlalchemy import TypeDecorator
+
+            class MyArrayType(TypeDecorator):
+                impl = ARRAY
+
+                cache_ok = True
+
+                def coerce_compared_value(self, op, value):
+                    return self.impl.coerce_compared_value(op, value)
+
+
     """
 
     __visit_name__ = "type_decorator"
index a154666bb678bc6bcb9c6032e0c94b809d916bc7..a608d0040b6697d30f1a2e742eea1b4ee917b55a 100644 (file)
@@ -782,6 +782,136 @@ class UserDefinedRoundTripTest(_UserDefinedTypeFixture, fixtures.TablesTest):
         eq_(result.fetchall(), [(3, 1500), (4, 900)])
 
 
+class TypeDecoratorSpecialCasesTest(AssertsCompiledSQL, fixtures.TestBase):
+    __backend__ = True
+
+    @testing.requires.array_type
+    def test_typedec_of_array_modified(self, metadata, connection):
+        """test #7249"""
+
+        class SkipsFirst(TypeDecorator):  # , Indexable):
+            impl = ARRAY(Integer, zero_indexes=True)
+
+            cache_ok = True
+
+            def process_bind_param(self, value, dialect):
+                return value[1:]
+
+            def copy(self, **kw):
+                return SkipsFirst(**kw)
+
+            def coerce_compared_value(self, op, value):
+                return self.impl.coerce_compared_value(op, value)
+
+        t = Table(
+            "t",
+            metadata,
+            Column("id", Integer, primary_key=True),
+            Column("data", SkipsFirst),
+        )
+        t.create(connection)
+
+        connection.execute(t.insert(), {"data": [1, 2, 3]})
+        val = connection.scalar(select(t.c.data))
+        eq_(val, [2, 3])
+
+        val = connection.scalar(select(t.c.data[0]))
+        eq_(val, 2)
+
+    def test_typedec_of_array_ops(self):
+        class ArrayDec(TypeDecorator):
+            impl = ARRAY(Integer, zero_indexes=True)
+
+            cache_ok = True
+
+            def coerce_compared_value(self, op, value):
+                return self.impl.coerce_compared_value(op, value)
+
+        expr1 = column("q", ArrayDec)[0]
+        expr2 = column("q", ARRAY(Integer, zero_indexes=True))[0]
+
+        eq_(expr1.right.type._type_affinity, Integer)
+        eq_(expr2.right.type._type_affinity, Integer)
+
+        self.assert_compile(
+            column("q", ArrayDec).any(7, operator=operators.lt),
+            "%(q_1)s < ANY (q)",
+            dialect="postgresql",
+        )
+
+        self.assert_compile(
+            column("q", ArrayDec)[5], "q[%(q_1)s]", dialect="postgresql"
+        )
+
+    def test_typedec_of_json_ops(self):
+        class JsonDec(TypeDecorator):
+            impl = JSON()
+
+            cache_ok = True
+
+        self.assert_compile(
+            column("q", JsonDec)["q"], "q -> %(q_1)s", dialect="postgresql"
+        )
+
+        self.assert_compile(
+            column("q", JsonDec)["q"].as_integer(),
+            "CAST(q ->> %(q_1)s AS INTEGER)",
+            dialect="postgresql",
+        )
+
+    @testing.requires.array_type
+    def test_typedec_of_array(self, metadata, connection):
+        """test #7249"""
+
+        class ArrayDec(TypeDecorator):
+            impl = ARRAY(Integer, zero_indexes=True)
+
+            cache_ok = True
+
+            def coerce_compared_value(self, op, value):
+                return self.impl.coerce_compared_value(op, value)
+
+        t = Table(
+            "t",
+            metadata,
+            Column("id", Integer, primary_key=True),
+            Column("data", ArrayDec),
+        )
+
+        t.create(connection)
+
+        connection.execute(t.insert(), {"data": [1, 2, 3]})
+        val = connection.scalar(select(t.c.data))
+        eq_(val, [1, 2, 3])
+
+        val = connection.scalar(select(t.c.data[0]))
+        eq_(val, 1)
+
+    @testing.requires.json_type
+    def test_typedec_of_json(self, metadata, connection):
+        """test #7249"""
+
+        class JsonDec(TypeDecorator):
+            impl = JSON()
+
+            cache_ok = True
+
+        t = Table(
+            "t",
+            metadata,
+            Column("id", Integer, primary_key=True),
+            Column("data", JsonDec),
+        )
+        t.create(connection)
+
+        connection.execute(t.insert(), {"data": {"key": "value"}})
+        val = connection.scalar(select(t.c.data))
+        eq_(val, {"key": "value"})
+
+        val = connection.scalar(select(t.c.data["key"].as_string()))
+        eq_(val, "value")
+
+
 class BindProcessorInsertValuesTest(UserDefinedRoundTripTest):
     """related to #6770, test that insert().values() applies to
     bound parameter handlers including the None value."""