]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Implement primary key custom sorting. 4816/head
authorNicolas CANIART <nicolas@caniart.net>
Thu, 22 Aug 2019 09:44:48 +0000 (11:44 +0200)
committerNicolas CANIART <nicolas@caniart.net>
Thu, 22 Aug 2019 17:07:12 +0000 (19:07 +0200)
Fixes: #4285
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/persistence.py
lib/sqlalchemy/sql/sqltypes.py
lib/sqlalchemy/sql/type_api.py
test/orm/test_mapper.py
test/orm/test_naturalpks.py
test/orm/test_unitofwork.py
test/sql/test_types.py

index 5e8d25647b2d0e503bf00220817b66c55b26b377..07fd9f3fb627ed335ca270c53bedbe7c7f99965b 100644 (file)
@@ -2748,6 +2748,25 @@ class Mapper(InspectionAttr):
         )
         return identity_key[1]
 
+    @_memoized_configured_property
+    def _persistent_sortkey_fn(self):
+        key_fns = [col.type.sort_key_function for col in self.primary_key]
+
+        if set(key_fns).difference([None]):
+
+            def key(state):
+                return tuple(
+                    key_fn(val) if key_fn is not None else val
+                    for key_fn, val in zip(key_fns, state.key[1])
+                )
+
+        else:
+
+            def key(state):
+                return state.key[1]
+
+        return key
+
     @_memoized_configured_property
     def _identity_key_props(self):
         return [self._columntoproperty[col] for col in self.primary_key]
index fb25d2405fcd3b09d7f82fa2b5ea0ac6dc2c6821..68052dfdd1f54abca0530d5f66f50939a8db9043 100644 (file)
@@ -196,7 +196,7 @@ def save_obj(base_mapper, states, uowtransaction, single=False):
 
     # if batch=false, call _save_obj separately for each object
     if not single and not base_mapper.batch:
-        for state in _sort_states(states):
+        for state in _sort_states(base_mapper, states):
             save_obj(base_mapper, [state], uowtransaction, single=True)
         return
 
@@ -1607,7 +1607,7 @@ def _connections_for_states(base_mapper, uowtransaction, states):
         connection = uowtransaction.transaction.connection(base_mapper)
         connection_callable = None
 
-    for state in _sort_states(states):
+    for state in _sort_states(base_mapper, states):
         if connection_callable:
             connection = connection_callable(base_mapper, state.obj())
 
@@ -1625,12 +1625,15 @@ def _cached_connection_dict(base_mapper):
     )
 
 
-def _sort_states(states):
+def _sort_states(mapper, states):
     pending = set(states)
     persistent = set(s for s in pending if s.key is not None)
     pending.difference_update(persistent)
+
     try:
-        persistent_sorted = sorted(persistent, key=lambda q: q.key[1])
+        persistent_sorted = sorted(
+            persistent, key=mapper._persistent_sortkey_fn
+        )
     except TypeError as err:
         raise sa_exc.InvalidRequestError(
             "Could not sort objects by primary key; primary key "
index 631352ceb7986d05de70c34cd8df7354fc34ff48..4c1d86924bde73081937d51f504344eebb7da864 100644 (file)
@@ -1356,6 +1356,18 @@ class Enum(Emulated, String, SchemaType):
 
            .. versionadded:: 1.2.3
 
+        :param sort_key_function: a Python callable which may be used as the
+           "key" argument in the Python sorted() built in.   Python
+           Enum objects are not inherently sortable; for projects that want
+           to use Enum objects as the type for primary key columns with the
+           ORM, an **arbitrary** sorting must be provided; this does not need
+           to match the database sorting, it only needs to be deterministic
+           based on the values of the enum itself (e.g. sorts the same way
+           every time).
+
+           .. versionadded:: 1.3.x
+
+
         """
         self._enum_init(enums, kw)
 
@@ -1377,6 +1389,7 @@ class Enum(Emulated, String, SchemaType):
         self.native_enum = kw.pop("native_enum", True)
         self.create_constraint = kw.pop("create_constraint", True)
         self.values_callable = kw.pop("values_callable", None)
+        self._sort_key_function = kw.pop("sort_key_function", None)
 
         values, objects = self._parse_into_values(enums, kw)
         self._setup_for_values(values, objects, kw)
@@ -1449,6 +1462,10 @@ class Enum(Emulated, String, SchemaType):
             ]
         )
 
+    @property
+    def sort_key_function(self):
+        return self._sort_key_function
+
     @property
     def native(self):
         return self.native_enum
index 9838f0d5afc5c919c5bd665cf2d38db841141c98..8b0287a1d90924e3b2811067a46b6b8c3aaef8bf 100644 (file)
@@ -370,6 +370,18 @@ class TypeEngine(Visitable):
 
         return x == y
 
+    @property
+    def sort_key_function(self):
+        """Return a sorting function that can be passed as the key to sorted.
+
+        Returns None by default, which indicates that the values stored by
+        this type are self-sorting.
+
+        .. versionadded:: 1.3.x
+
+        """
+        return None
+
     def get_dbapi_type(self, dbapi):
         """Return the corresponding type object from the underlying DB-API, if
         any.
@@ -1354,6 +1366,10 @@ class TypeDecorator(SchemaEventTarget, TypeEngine):
         """
         return self.impl.compare_values(x, y)
 
+    @property
+    def sort_key_function(self):
+        return self.impl.sort_key_function
+
     def __repr__(self):
         return util.generic_repr(self, to_inspect=self.impl)
 
index ceec344d9c29239451772583993ae28ca270fe27..93346b32fe40a2e69ce715508ae4fb30b4460073 100644 (file)
@@ -346,7 +346,7 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL):
         states[4].insert_order = DontCompareMeToString(1)
         states[2].insert_order = DontCompareMeToString(3)
         eq_(
-            _sort_states(states),
+            _sort_states(m, states),
             [states[4], states[3], states[0], states[1], states[2]],
         )
 
index 6108a28c4cab26e0685a4da1bf03dd964939ecd9..9a25a618dc9de5ef395535300b800a90c4875d5f 100644 (file)
@@ -3,11 +3,15 @@ Primary key changing capabilities and passive/non-passive cascading updates.
 
 """
 
+import itertools
+
 import sqlalchemy as sa
+from sqlalchemy import bindparam
 from sqlalchemy import ForeignKey
 from sqlalchemy import Integer
 from sqlalchemy import String
 from sqlalchemy import testing
+from sqlalchemy import TypeDecorator
 from sqlalchemy.orm import create_session
 from sqlalchemy.orm import mapper
 from sqlalchemy.orm import relationship
@@ -1754,6 +1758,74 @@ class JoinedInheritanceTest(fixtures.MappedTest):
         )
 
 
+class UnsortablePKTest(fixtures.MappedTest):
+    """Test integration with TypeEngine.sort_key_function"""
+
+    class HashableDict(dict):
+        def __hash__(self):
+            return hash((self["x"], self["y"]))
+
+    @classmethod
+    def define_tables(cls, metadata):
+        class MyUnsortable(TypeDecorator):
+            impl = String(10)
+
+            def process_bind_param(self, value, dialect):
+                return "%s,%s" % (value["x"], value["y"])
+
+            def process_result_value(self, value, dialect):
+                rec = value.split(",")
+                return cls.HashableDict({"x": rec[0], "y": rec[1]})
+
+            def sort_key_function(self, value):
+                return (value["x"], value["y"])
+
+        Table(
+            "data",
+            metadata,
+            Column("info", MyUnsortable(), primary_key=True),
+            Column("int_value", Integer),
+        )
+
+    @classmethod
+    def setup_classes(cls):
+        class Data(cls.Comparable):
+            pass
+
+    @classmethod
+    def setup_mappers(cls):
+        mapper(cls.classes.Data, cls.tables.data)
+
+    def test_updates_sorted(self):
+        Data = self.classes.Data
+        s = Session()
+
+        s.add_all(
+            [
+                Data(info=self.HashableDict(x="a", y="b")),
+                Data(info=self.HashableDict(x="a", y="a")),
+                Data(info=self.HashableDict(x="b", y="b")),
+                Data(info=self.HashableDict(x="b", y="a")),
+            ]
+        )
+        s.commit()
+
+        aa, ab, ba, bb = s.query(Data).order_by(Data.info).all()
+
+        counter = itertools.count()
+        ab.int_value = bindparam(key=None, callable_=lambda: next(counter))
+        ba.int_value = bindparam(key=None, callable_=lambda: next(counter))
+        bb.int_value = bindparam(key=None, callable_=lambda: next(counter))
+        aa.int_value = bindparam(key=None, callable_=lambda: next(counter))
+
+        s.commit()
+
+        eq_(
+            s.query(Data.int_value).order_by(Data.info).all(),
+            [(0,), (1,), (2,), (3,)],
+        )
+
+
 class JoinedInheritancePKOnFKTest(fixtures.MappedTest):
     """Test cascades of pk->non-pk/fk on joined table inh."""
 
index 13c5907a4f03b6ee0519ed8188c73e94dd6d02cb..5b981297245c24470fde7b8b63235e125426ca0c 100644 (file)
@@ -15,12 +15,14 @@ from sqlalchemy import literal_column
 from sqlalchemy import select
 from sqlalchemy import String
 from sqlalchemy import testing
+from sqlalchemy.inspection import inspect
 from sqlalchemy.orm import column_property
 from sqlalchemy.orm import create_session
 from sqlalchemy.orm import exc as orm_exc
 from sqlalchemy.orm import mapper
 from sqlalchemy.orm import relationship
 from sqlalchemy.orm import Session
+from sqlalchemy.orm.persistence import _sort_states
 from sqlalchemy.testing import assert_raises
 from sqlalchemy.testing import assert_raises_message
 from sqlalchemy.testing import eq_
@@ -3398,6 +3400,7 @@ class EnsurePKSortableTest(fixtures.MappedTest):
     two = MySortableEnum("two", 2)
     three = MyNotSortableEnum("three", 3)
     four = MyNotSortableEnum("four", 4)
+    five = MyNotSortableEnum("five", 5)
 
     @classmethod
     def define_tables(cls, metadata):
@@ -3415,6 +3418,24 @@ class EnsurePKSortableTest(fixtures.MappedTest):
             Column("data", String(10)),
         )
 
+        Table(
+            "t3",
+            metadata,
+            Column(
+                "id",
+                Enum(
+                    cls.MyNotSortableEnum,
+                    sort_key_function=cls.sort_enum_key_value,
+                ),
+                primary_key=True,
+            ),
+            Column("value", Integer),
+        )
+
+    @staticmethod
+    def sort_enum_key_value(value):
+        return value.value
+
     @classmethod
     def setup_classes(cls):
         class T1(cls.Basic):
@@ -3423,10 +3444,15 @@ class EnsurePKSortableTest(fixtures.MappedTest):
         class T2(cls.Basic):
             pass
 
+        class T3(cls.Basic):
+            def __str__(self):
+                return 'T3(id={})'.format(self.id)
+
     @classmethod
     def setup_mappers(cls):
         mapper(cls.classes.T1, cls.tables.t1)
         mapper(cls.classes.T2, cls.tables.t2)
+        mapper(cls.classes.T3, cls.tables.t3)
 
     def test_exception_persistent_flush_py3k(self):
         s = Session()
@@ -3459,3 +3485,21 @@ class EnsurePKSortableTest(fixtures.MappedTest):
         a.data = "bar"
         b.data = "foo"
         s.commit()
+
+    def test_pep435_custom_sort_key(self):
+        s = Session()
+
+        a = self.classes.T3(id=self.three, value=1)
+        b = self.classes.T3(id=self.four, value=2)
+        s.add_all([a, b])
+        s.commit()
+
+        c = self.classes.T3(id=self.five, value=0)
+        s.add(c)
+
+        states = [o._sa_instance_state for o in [b, a, c]]
+        eq_(
+            _sort_states(inspect(self.classes.T3), states),
+            # pending come first, then 3 < 4
+            [o._sa_instance_state for o in [c, a, b]],
+        )
index a5c9313f80d2bda2b9beabfae8b424c642d93ef0..b763d5fddda104e47f04213b24ce144d3fb6434d 100644 (file)
@@ -1658,6 +1658,22 @@ class EnumTest(AssertsCompiledSQL, fixtures.TablesTest):
             [(1, "two"), (2, "two"), (3, "one")],
         )
 
+    def test_pep435_custom_sort_key(self):
+        def sort_enum_key_value(value):
+            return value.value
+
+        table = Table(
+            "stdlib_enum_table3",
+            MetaData(),
+            Column(
+                "someenum",
+                Enum(self.SomeEnum, sort_key_function=sort_enum_key_value),
+                primary_key=True,
+            ),
+        )
+
+        eq_(table.c.someenum.type.sort_key_function, sort_enum_key_value)
+
     def test_pep435_enum_round_trip(self):
         stdlib_enum_table = self.tables["stdlib_enum_table"]