git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Apply new uniquing rules for future ORM selects
author Mike Bayer <mike_mp@zzzcomputing.com>
Wed, 21 Jul 2021 19:44:27 +0000 (15:44 -0400)
committer Mike Bayer <mike_mp@zzzcomputing.com>
Wed, 21 Jul 2021 22:02:02 +0000 (18:02 -0400)
Fixed issue where usage of the :meth:`_result.Result.unique` method with an
ORM result that included column expressions with unhashable types, such as
``JSON`` or ``ARRAY`` using non-tuples, would silently fall back to using
the ``id()`` function, rather than raising an error. This now raises an
error when the :meth:`_result.Result.unique` method is used in a 2.0 style
ORM query. Additionally, hashability is assumed to be True for result
values of unknown type, such as often happens when using SQL functions of
unknown return type; if values are truly not hashable then the ``hash()``
call itself will raise.

For legacy ORM queries, since the legacy :class:`_orm.Query` object
uniquifies in all cases, the old rules remain in place, which is to use
``id()`` for result values of unknown type, as this legacy uniquing is
mostly for the purpose of uniquing ORM entities and not column values.

Fixes: #6769
Change-Id: I5747f706f1e97c78867b5cf28c73360497273808

doc/build/changelog/unreleased_14/6769.rst [new file with mode: 0644]
lib/sqlalchemy/orm/context.py
lib/sqlalchemy/orm/loading.py
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/sql/sqltypes.py
test/orm/test_query.py

diff --git a/doc/build/changelog/unreleased_14/6769.rst b/doc/build/changelog/unreleased_14/6769.rst
new file mode 100644 (file)
index 0000000..05ddc8d
--- /dev/null
@@ -0,0 +1,18 @@
+.. change::
+    :tags: bug, orm
+    :tickets: 6769
+
+    Fixed issue where usage of the :meth:`_result.Result.unique` method with an
+    ORM result that included column expressions with unhashable types, such as
+    ``JSON`` or ``ARRAY`` using non-tuples would silently fall back to using
+    the ``id()`` function, rather than raising an error. This now raises an
+    error when the :meth:`_result.Result.unique` method is used in a 2.0 style
+    ORM query. Additionally, hashability is assumed to be True for result
+    values of unknown type, such as often happens when using SQL functions of
+    unknown return type; if values are truly not hashable then the ``hash()``
+    itself will raise.
+
+    For legacy ORM queries, since the legacy :class:`_orm.Query` object
+    uniquifies in all cases, the old rules remain in place, which is to use
+    ``id()`` for result values of unknown type as this legacy uniquing is
+    mostly for the purpose of uniquing ORM entities and not column values.
\ No newline at end of file
index 0af3fd6afdfe99b7f749e1dedbbe9e8bae372649..c4b695687631dffb85ed1f53ecf07e5080cee853 100644 (file)
@@ -81,6 +81,7 @@ class QueryContext(object):
         _yield_per = None
         _refresh_state = None
         _lazy_loaded_from = None
+        _legacy_uniquing = False
 
     def __init__(
         self,
@@ -2257,6 +2258,10 @@ class _QueryEntity(object):
 
     __slots__ = ()
 
+    _non_hashable_value = False
+    _null_column_type = False
+    use_id_for_hash = False
+
     @classmethod
     def to_compile_state(cls, compile_state, entities, entities_collection):
 
@@ -2387,6 +2392,7 @@ class _MapperEntity(_QueryEntity):
 
     supports_single_entity = True
 
+    _non_hashable_value = True
     use_id_for_hash = True
 
     @property
@@ -2483,7 +2489,6 @@ class _MapperEntity(_QueryEntity):
 
 
 class _BundleEntity(_QueryEntity):
-    use_id_for_hash = False
 
     _extra_entities = ()
 
@@ -2663,9 +2668,13 @@ class _ColumnEntity(_QueryEntity):
         return self.column.type
 
     @property
-    def use_id_for_hash(self):
+    def _non_hashable_value(self):
         return not self.column.type.hashable
 
+    @property
+    def _null_column_type(self):
+        return self.column.type._isnull
+
     def row_processor(self, context, result):
         compile_state = context.compile_state
 
index 948f33ad54b64e6120c7e5011bee11755908d4df..abc8780ed94d382ab620a77481647e096d3b155e 100644 (file)
@@ -92,17 +92,43 @@ def instances(cursor, context):
             "Can't use the ORM yield_per feature in conjunction with unique()"
         )
 
-    row_metadata = SimpleResultMetaData(
-        labels,
-        extra,
-        _unique_filters=[
+    def _not_hashable(datatype):
+        def go(obj):
+            raise sa_exc.InvalidRequestError(
+                "Can't apply uniqueness to row tuple containing value of "
+                "type %r; this datatype produces non-hashable values"
+                % datatype
+            )
+
+        return go
+
+    if context.load_options._legacy_uniquing:
+        unique_filters = [
+            _no_unique
+            if context.yield_per
+            else id
+            if (
+                ent.use_id_for_hash
+                or ent._non_hashable_value
+                or ent._null_column_type
+            )
+            else None
+            for ent in context.compile_state._entities
+        ]
+    else:
+        unique_filters = [
             _no_unique
             if context.yield_per
+            else _not_hashable(ent.column.type)
+            if (not ent.use_id_for_hash and ent._non_hashable_value)
             else id
             if ent.use_id_for_hash
             else None
             for ent in context.compile_state._entities
-        ],
+        ]
+
+    row_metadata = SimpleResultMetaData(
+        labels, extra, _unique_filters=unique_filters
     )
 
     def chunks(size):
index 6ad7f3020e0c918845bac66eabe53767c1d1440b..9a97d37b0913772f23821fe1b79f7185fd0d9819 100644 (file)
@@ -130,7 +130,9 @@ class Query(
 
     _compile_options = ORMCompileState.default_compile_options
 
-    load_options = QueryContext.default_load_options
+    load_options = QueryContext.default_load_options + {
+        "_legacy_uniquing": True
+    }
 
     _params = util.EMPTY_DICT
 
index 1b05465c99742520637ad76a8dafa1bf7a3ba18d..44431d38fce88617f5a8fc0965908f30a5c892e7 100644 (file)
@@ -3184,8 +3184,6 @@ class NullType(TypeEngine):
 
     _isnull = True
 
-    hashable = False
-
     def literal_processor(self, dialect):
         def process(value):
             raise exc.CompileError(
index 3c806e9d5c333dcbf5bdf5d3964b51bd42a9980f..878e914b6e4948e69397a42d920eb414566da6f5 100644 (file)
@@ -79,6 +79,8 @@ from sqlalchemy.testing.assertsql import CompiledSQL
 from sqlalchemy.testing.fixtures import fixture_session
 from sqlalchemy.testing.schema import Column
 from sqlalchemy.testing.schema import Table
+from sqlalchemy.types import NullType
+from sqlalchemy.types import TypeDecorator
 from sqlalchemy.util import collections_abc
 from test.orm import _fixtures
 
@@ -526,10 +528,43 @@ class RowTupleTest(QueryTest):
 
         eq_(q.column_descriptions, asserted)
 
-    def test_unhashable_type(self):
-        from sqlalchemy.types import TypeDecorator, Integer
-        from sqlalchemy.sql import type_coerce
+    def test_unhashable_type_legacy(self):
+        class MyType(TypeDecorator):
+            impl = Integer
+            hashable = False
+            cache_ok = True
+
+            def process_result_value(self, value, dialect):
+                return [value]
+
+        User, users = self.classes.User, self.tables.users
+        Address, addresses = self.classes.Address, self.tables.addresses
+        mapper(User, users, properties={"addresses": relationship(Address)})
+        mapper(Address, addresses)
+
+        s = fixture_session()
+        q = (
+            s.query(User, type_coerce(users.c.id, MyType).label("foo"))
+            .filter(User.id.in_([7, 8]))
+            .join(User.addresses)
+            .order_by(User.id)
+        )
+
+        result = q.all()
 
+        # uniquing basically does not occur because we can't hash on
+        # MyType
+        eq_(
+            result,
+            [
+                (User(id=7), [7]),
+                (User(id=8), [8]),
+                (User(id=8), [8]),
+                (User(id=8), [8]),
+            ],
+        )
+
+    def test_unhashable_type_future(self):
         class MyType(TypeDecorator):
             impl = Integer
             hashable = False
@@ -539,15 +574,118 @@ class RowTupleTest(QueryTest):
                 return [value]
 
         User, users = self.classes.User, self.tables.users
+        Address, addresses = self.classes.Address, self.tables.addresses
+        mapper(User, users, properties={"addresses": relationship(Address)})
+        mapper(Address, addresses)
+
+        s = fixture_session()
+
+        stmt = (
+            select(User, type_coerce(users.c.id, MyType).label("foo"))
+            .filter(User.id.in_([7, 8]))
+            .join(User.addresses)
+            .order_by(User.id)
+        )
+
+        result = s.execute(stmt).unique()
+
+        with expect_raises_message(
+            sa_exc.InvalidRequestError,
+            r"Can't apply uniqueness to row tuple "
+            r"containing value of type MyType\(\)",
+        ):
+            result.all()
+
+    def test_unknown_type_assume_not_hashable_legacy(self):
+        User, users = self.classes.User, self.tables.users
 
-        mapper(User, users)
+        User, users = self.classes.User, self.tables.users
+        Address, addresses = self.classes.Address, self.tables.addresses
+        mapper(User, users, properties={"addresses": relationship(Address)})
+        mapper(Address, addresses)
 
         s = fixture_session()
-        q = s.query(User, type_coerce(users.c.id, MyType).label("foo")).filter(
-            User.id == 7
+
+        q = (
+            s.query(
+                User, type_coerce("Some Unique String", NullType).label("foo")
+            )
+            .filter(User.id.in_([7, 8]))
+            .join(User.addresses)
+            .order_by(User.id)
         )
-        row = q.first()
-        eq_(row, (User(id=7), [7]))
+
+        result = q.all()
+
+        eq_(
+            result,
+            [
+                (User(id=7, name="jack"), "Some Unique String"),
+                (User(id=8, name="ed"), "Some Unique String"),
+                (User(id=8, name="ed"), "Some Unique String"),
+                (User(id=8, name="ed"), "Some Unique String"),
+            ],
+        )
+
+    def test_unknown_type_assume_hashable_future(self):
+        User, users = self.classes.User, self.tables.users
+
+        User, users = self.classes.User, self.tables.users
+        Address, addresses = self.classes.Address, self.tables.addresses
+        mapper(User, users, properties={"addresses": relationship(Address)})
+        mapper(Address, addresses)
+
+        s = fixture_session()
+
+        # TODO: it's also unusual I need a label() for type_coerce
+        stmt = (
+            select(
+                User, type_coerce("Some Unique String", NullType).label("foo")
+            )
+            .filter(User.id.in_([7, 8]))
+            .join(User.addresses)
+            .order_by(User.id)
+        )
+
+        result = s.execute(stmt).unique()
+
+        eq_(
+            result.all(),
+            [
+                (User(id=7, name="jack"), "Some Unique String"),
+                (User(id=8, name="ed"), "Some Unique String"),
+            ],
+        )
+
+    def test_unknown_type_truly_not_hashable_future(self):
+        User, users = self.classes.User, self.tables.users
+
+        User, users = self.classes.User, self.tables.users
+        Address, addresses = self.classes.Address, self.tables.addresses
+        mapper(User, users, properties={"addresses": relationship(Address)})
+        mapper(Address, addresses)
+
+        class MyType(TypeDecorator):
+            impl = Integer
+            hashable = True  # which is wrong
+            cache_ok = True
+
+            def process_result_value(self, value, dialect):
+                return [value]
+
+        s = fixture_session()
+
+        stmt = (
+            select(User, type_coerce(User.id, MyType).label("foo"))
+            .filter(User.id.in_([7, 8]))
+            .join(User.addresses)
+            .order_by(User.id)
+        )
+
+        result = s.execute(stmt).unique()
+
+        with expect_raises_message(TypeError, "unhashable type"):
+            result.all()
 
 
 class RowLabelingTest(QueryTest):