From: Mike Bayer Date: Tue, 19 Jul 2022 14:50:05 +0000 (-0400) Subject: check for TypeDecorator when handling getitem X-Git-Tag: rel_1_4_40~22^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=81d20ae7f23c9e3f3487dc91687f934ee6ae124c;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git check for TypeDecorator when handling getitem Fixed issue where :class:`.TypeDecorator` would not correctly proxy the ``__getitem__()`` operator when decorating the :class:`.ARRAY` datatype, without explicit workarounds. Fixes: #7249 Change-Id: I3273572b4757e41fb5952639cb867314227d368a (cherry picked from commit 1e01fab7e600c53284eabceceab5706e4074eb2e) --- diff --git a/doc/build/changelog/unreleased_14/7249.rst b/doc/build/changelog/unreleased_14/7249.rst new file mode 100644 index 0000000000..5d0cb65818 --- /dev/null +++ b/doc/build/changelog/unreleased_14/7249.rst @@ -0,0 +1,7 @@ +.. change:: + :tags: bug, types + :tickets: 7249 + + Fixed issue where :class:`.TypeDecorator` would not correctly proxy the + ``__getitem__()`` operator when decorating the :class:`.ARRAY` datatype, + without explicit workarounds. diff --git a/lib/sqlalchemy/sql/default_comparator.py b/lib/sqlalchemy/sql/default_comparator.py index 7d2f1dd2a4..70586c696f 100644 --- a/lib/sqlalchemy/sql/default_comparator.py +++ b/lib/sqlalchemy/sql/default_comparator.py @@ -168,7 +168,11 @@ def _in_impl(expr, op, seq_or_selectable, negate_op, **kw): def _getitem_impl(expr, op, other, **kw): - if isinstance(expr.type, type_api.INDEXABLE): + if ( + isinstance(expr.type, type_api.INDEXABLE) + or isinstance(expr.type, type_api.TypeDecorator) + and isinstance(expr.type.impl, type_api.INDEXABLE) + ): other = coercions.expect( roles.BinaryElementRole, other, expr=expr, operator=op ) diff --git a/lib/sqlalchemy/sql/type_api.py b/lib/sqlalchemy/sql/type_api.py index 7431c08a41..172ce0d884 100644 --- a/lib/sqlalchemy/sql/type_api.py +++ b/lib/sqlalchemy/sql/type_api.py @@ -1265,8 +1265,11 @@ class TypeDecorator(ExternalType, SchemaEventTarget, TypeEngine): the default rules of :meth:`.TypeEngine.coerce_compared_value` should be used in order to deal with operators like index operations:: + from sqlalchemy import JSON + from sqlalchemy import TypeDecorator + class MyJsonType(TypeDecorator): - impl = postgresql.JSON + impl = JSON cache_ok = True @@ -1276,6 +1279,24 @@ class TypeDecorator(ExternalType, SchemaEventTarget, TypeEngine): Without the above step, index operations such as ``mycol['foo']`` will cause the index value ``'foo'`` to be JSON encoded. + Similarly, when working with the :class:`.ARRAY` datatype, the + type coercion for index operations (e.g. ``mycol[5]``) is also + handled by :meth:`.TypeDecorator.coerce_compared_value`, where + again a simple override is sufficient unless special rules are needed + for particular operators:: + + from sqlalchemy import ARRAY + from sqlalchemy import TypeDecorator + + class MyArrayType(TypeDecorator): + impl = ARRAY + + cache_ok = True + + def coerce_compared_value(self, op, value): + return self.impl.coerce_compared_value(op, value) + + """ __visit_name__ = "type_decorator" diff --git a/test/sql/test_types.py b/test/sql/test_types.py index 12932a1c9c..c4f2f27260 100644 --- a/test/sql/test_types.py +++ b/test/sql/test_types.py @@ -784,6 +784,136 @@ class UserDefinedRoundTripTest(_UserDefinedTypeFixture, fixtures.TablesTest): eq_(result.fetchall(), [(3, 1500), (4, 900)]) +class TypeDecoratorSpecialCasesTest(AssertsCompiledSQL, fixtures.TestBase): + __backend__ = True + + @testing.requires.array_type + def test_typedec_of_array_modified(self, metadata, connection): + """test #7249""" + + class SkipsFirst(TypeDecorator): # , Indexable): + impl = ARRAY(Integer, zero_indexes=True) + + cache_ok = True + + def process_bind_param(self, value, dialect): + return value[1:] + + def copy(self, **kw): + return SkipsFirst(**kw) + + def coerce_compared_value(self, op, value): + return self.impl.coerce_compared_value(op, value) + + t = Table( + "t", + metadata, + Column("id", Integer, primary_key=True), + Column("data", SkipsFirst), + ) + t.create(connection) + + connection.execute(t.insert(), {"data": [1, 2, 3]}) + val = connection.scalar(select(t.c.data)) + eq_(val, [2, 3]) + + val = connection.scalar(select(t.c.data[0])) + eq_(val, 2) + + def test_typedec_of_array_ops(self): + class ArrayDec(TypeDecorator): + impl = ARRAY(Integer, zero_indexes=True) + + cache_ok = True + + def coerce_compared_value(self, op, value): + return self.impl.coerce_compared_value(op, value) + + expr1 = column("q", ArrayDec)[0] + expr2 = column("q", ARRAY(Integer, zero_indexes=True))[0] + + eq_(expr1.right.type._type_affinity, Integer) + eq_(expr2.right.type._type_affinity, Integer) + + self.assert_compile( + column("q", ArrayDec).any(7, operator=operators.lt), + "%(q_1)s < ANY (q)", + dialect="postgresql", + ) + + self.assert_compile( + column("q", ArrayDec)[5], "q[%(q_1)s]", dialect="postgresql" + ) + + def test_typedec_of_json_ops(self): + class JsonDec(TypeDecorator): + impl = JSON() + + cache_ok = True + + self.assert_compile( + column("q", JsonDec)["q"], "q -> %(q_1)s", dialect="postgresql" + ) + + self.assert_compile( + column("q", JsonDec)["q"].as_integer(), + "CAST(q ->> %(q_1)s AS INTEGER)", + dialect="postgresql", + ) + + @testing.requires.array_type + def test_typedec_of_array(self, metadata, connection): + """test #7249""" + + class ArrayDec(TypeDecorator): + impl = ARRAY(Integer, zero_indexes=True) + + cache_ok = True + + def coerce_compared_value(self, op, value): + return self.impl.coerce_compared_value(op, value) + + t = Table( + "t", + metadata, + Column("id", Integer, primary_key=True), + Column("data", ArrayDec), + ) + + t.create(connection) + + connection.execute(t.insert(), {"data": [1, 2, 3]}) + val = connection.scalar(select(t.c.data)) + eq_(val, [1, 2, 3]) + + val = connection.scalar(select(t.c.data[0])) + eq_(val, 1) + + @testing.requires.json_type + def test_typedec_of_json(self, metadata, connection): + """test #7249""" + + class JsonDec(TypeDecorator): + impl = JSON() + + cache_ok = True + + t = Table( + "t", + metadata, + Column("id", Integer, primary_key=True), + Column("data", JsonDec), + ) + t.create(connection) + + connection.execute(t.insert(), {"data": {"key": "value"}}) + val = connection.scalar(select(t.c.data)) + eq_(val, {"key": "value"}) + + val = connection.scalar(select(t.c.data["key"].as_string())) + eq_(val, "value") + + class BindProcessorInsertValuesTest(UserDefinedRoundTripTest): """related to #6770, test that insert().values() applies to bound parameter handlers including the None value."""