]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Implement type-level sorting for Enum; apply to ORM primary keys
author Nicolas CANIART <nicolas@caniart.net>
Thu, 22 Aug 2019 18:16:29 +0000 (14:16 -0400)
committer Mike Bayer <mike_mp@zzzcomputing.com>
Tue, 27 Aug 2019 16:59:21 +0000 (12:59 -0400)
Added support for the use of an :class:`.Enum` datatype using Python
pep-435 enumeration objects as values for use as a primary key column
mapped by the ORM.  As these values are not inherently sortable, as
required by the ORM for primary keys, a new
:attr:`.TypeEngine.sort_key_function` attribute is added to the typing
system which allows any SQL type to implement a sorting for Python objects
of its type which is consulted by the unit of work.  The :class:`.Enum`
type then defines this using the database value of a given enumeration.
The sorting scheme can also be redefined by passing a callable to the
:paramref:`.Enum.sort_key_function` parameter.  Pull request courtesy
Nicolas Caniart.

Fixes: #4285
Closes: #4816
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/4816
Pull-request-sha: 42266b766c1e462d5b8a409cda05d33dea13bd34

Change-Id: Iadcc16173c1ba26ffac5830db57743a4cb987c55

doc/build/changelog/unreleased_13/4285.rst [new file with mode: 0644]
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/persistence.py
lib/sqlalchemy/sql/sqltypes.py
lib/sqlalchemy/sql/type_api.py
test/orm/test_mapper.py
test/orm/test_naturalpks.py
test/orm/test_unitofwork.py
test/sql/test_types.py

diff --git a/doc/build/changelog/unreleased_13/4285.rst b/doc/build/changelog/unreleased_13/4285.rst
new file mode 100644 (file)
index 0000000..1049a58
--- /dev/null
@@ -0,0 +1,15 @@
+.. change::
+    :tags: usecase, orm
+    :tickets: 4285
+
+    Added support for the use of an :class:`.Enum` datatype using Python
+    pep-435 enumeration objects as values for use as a primary key column
+    mapped by the ORM.  As these values are not inherently sortable, as
+    required by the ORM for primary keys, a new
+    :attr:`.TypeEngine.sort_key_function` attribute is added to the typing
+    system which allows any SQL type to implement a sorting for Python objects
+    of its type which is consulted by the unit of work.  The :class:`.Enum`
+    type then defines this using the database value of a given enumeration.
+    The sorting scheme can also be redefined by passing a callable to the
+    :paramref:`.Enum.sort_key_function` parameter.  Pull request courtesy
+    Nicolas Caniart.
index 5e8d25647b2d0e503bf00220817b66c55b26b377..07fd9f3fb627ed335ca270c53bedbe7c7f99965b 100644 (file)
@@ -2748,6 +2748,25 @@ class Mapper(InspectionAttr):
         )
         return identity_key[1]
 
+    @_memoized_configured_property
+    def _persistent_sortkey_fn(self):
+        key_fns = [col.type.sort_key_function for col in self.primary_key]
+
+        if set(key_fns).difference([None]):
+
+            def key(state):
+                return tuple(
+                    key_fn(val) if key_fn is not None else val
+                    for key_fn, val in zip(key_fns, state.key[1])
+                )
+
+        else:
+
+            def key(state):
+                return state.key[1]
+
+        return key
+
     @_memoized_configured_property
     def _identity_key_props(self):
         return [self._columntoproperty[col] for col in self.primary_key]
index fb25d2405fcd3b09d7f82fa2b5ea0ac6dc2c6821..68052dfdd1f54abca0530d5f66f50939a8db9043 100644 (file)
@@ -196,7 +196,7 @@ def save_obj(base_mapper, states, uowtransaction, single=False):
 
     # if batch=false, call _save_obj separately for each object
     if not single and not base_mapper.batch:
-        for state in _sort_states(states):
+        for state in _sort_states(base_mapper, states):
             save_obj(base_mapper, [state], uowtransaction, single=True)
         return
 
@@ -1607,7 +1607,7 @@ def _connections_for_states(base_mapper, uowtransaction, states):
         connection = uowtransaction.transaction.connection(base_mapper)
         connection_callable = None
 
-    for state in _sort_states(states):
+    for state in _sort_states(base_mapper, states):
         if connection_callable:
             connection = connection_callable(base_mapper, state.obj())
 
@@ -1625,12 +1625,15 @@ def _cached_connection_dict(base_mapper):
     )
 
 
-def _sort_states(states):
+def _sort_states(mapper, states):
     pending = set(states)
     persistent = set(s for s in pending if s.key is not None)
     pending.difference_update(persistent)
+
     try:
-        persistent_sorted = sorted(persistent, key=lambda q: q.key[1])
+        persistent_sorted = sorted(
+            persistent, key=mapper._persistent_sortkey_fn
+        )
     except TypeError as err:
         raise sa_exc.InvalidRequestError(
             "Could not sort objects by primary key; primary key "
index 631352ceb7986d05de70c34cd8df7354fc34ff48..fd15d7c795ef142da7bfac0a5d9e26c308e02086 100644 (file)
@@ -20,6 +20,7 @@ from . import operators
 from . import roles
 from . import type_api
 from .base import _bind_or_error
+from .base import NO_ARG
 from .base import SchemaEventTarget
 from .elements import _defer_name
 from .elements import quoted_name
@@ -1356,6 +1357,19 @@ class Enum(Emulated, String, SchemaType):
 
            .. versionadded:: 1.2.3
 
+        :param sort_key_function: a Python callable which may be used as the
+           "key" argument in the Python ``sorted()`` built-in.   The SQLAlchemy
+           ORM requires that primary key columns which are mapped must
+           be sortable in some way.  When using an unsortable enumeration
+           object such as a Python 3 ``Enum`` object, this parameter may be
+           used to set a default sort key function for the objects.  By
+           default, the database value of the enumeration is used as the
+           sorting function.
+
+           .. versionadded:: 1.3.8
+
+
+
         """
         self._enum_init(enums, kw)
 
@@ -1377,6 +1391,7 @@ class Enum(Emulated, String, SchemaType):
         self.native_enum = kw.pop("native_enum", True)
         self.create_constraint = kw.pop("create_constraint", True)
         self.values_callable = kw.pop("values_callable", None)
+        self._sort_key_function = kw.pop("sort_key_function", NO_ARG)
 
         values, objects = self._parse_into_values(enums, kw)
         self._setup_for_values(values, objects, kw)
@@ -1449,6 +1464,13 @@ class Enum(Emulated, String, SchemaType):
             ]
         )
 
+    @property
+    def sort_key_function(self):
+        if self._sort_key_function is NO_ARG:
+            return self._db_value_for_elem
+        else:
+            return self._sort_key_function
+
     @property
     def native(self):
         return self.native_enum
index 9838f0d5afc5c919c5bd665cf2d38db841141c98..11407ad2e193ea406e5514421f793547218b424f 100644 (file)
@@ -135,6 +135,16 @@ class TypeEngine(Visitable):
 
     """
 
+    sort_key_function = None
+    """A sorting function that can be passed as the key to sorted.
+
+    The default value of ``None`` indicates that the values stored by
+    this type are self-sorting.
+
+    .. versionadded:: 1.3.8
+
+    """
+
     should_evaluate_none = False
     """If True, the Python constant ``None`` is considered to be handled
     explicitly by this type.
@@ -1354,6 +1364,10 @@ class TypeDecorator(SchemaEventTarget, TypeEngine):
         """
         return self.impl.compare_values(x, y)
 
+    @property
+    def sort_key_function(self):
+        return self.impl.sort_key_function
+
     def __repr__(self):
         return util.generic_repr(self, to_inspect=self.impl)
 
index ceec344d9c29239451772583993ae28ca270fe27..93346b32fe40a2e69ce715508ae4fb30b4460073 100644 (file)
@@ -346,7 +346,7 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL):
         states[4].insert_order = DontCompareMeToString(1)
         states[2].insert_order = DontCompareMeToString(3)
         eq_(
-            _sort_states(states),
+            _sort_states(m, states),
             [states[4], states[3], states[0], states[1], states[2]],
         )
 
index 6108a28c4cab26e0685a4da1bf03dd964939ecd9..9a25a618dc9de5ef395535300b800a90c4875d5f 100644 (file)
@@ -3,11 +3,15 @@ Primary key changing capabilities and passive/non-passive cascading updates.
 
 """
 
+import itertools
+
 import sqlalchemy as sa
+from sqlalchemy import bindparam
 from sqlalchemy import ForeignKey
 from sqlalchemy import Integer
 from sqlalchemy import String
 from sqlalchemy import testing
+from sqlalchemy import TypeDecorator
 from sqlalchemy.orm import create_session
 from sqlalchemy.orm import mapper
 from sqlalchemy.orm import relationship
@@ -1754,6 +1758,74 @@ class JoinedInheritanceTest(fixtures.MappedTest):
         )
 
 
+class UnsortablePKTest(fixtures.MappedTest):
+    """Test integration with TypeEngine.sort_key_function"""
+
+    class HashableDict(dict):
+        def __hash__(self):
+            return hash((self["x"], self["y"]))
+
+    @classmethod
+    def define_tables(cls, metadata):
+        class MyUnsortable(TypeDecorator):
+            impl = String(10)
+
+            def process_bind_param(self, value, dialect):
+                return "%s,%s" % (value["x"], value["y"])
+
+            def process_result_value(self, value, dialect):
+                rec = value.split(",")
+                return cls.HashableDict({"x": rec[0], "y": rec[1]})
+
+            def sort_key_function(self, value):
+                return (value["x"], value["y"])
+
+        Table(
+            "data",
+            metadata,
+            Column("info", MyUnsortable(), primary_key=True),
+            Column("int_value", Integer),
+        )
+
+    @classmethod
+    def setup_classes(cls):
+        class Data(cls.Comparable):
+            pass
+
+    @classmethod
+    def setup_mappers(cls):
+        mapper(cls.classes.Data, cls.tables.data)
+
+    def test_updates_sorted(self):
+        Data = self.classes.Data
+        s = Session()
+
+        s.add_all(
+            [
+                Data(info=self.HashableDict(x="a", y="b")),
+                Data(info=self.HashableDict(x="a", y="a")),
+                Data(info=self.HashableDict(x="b", y="b")),
+                Data(info=self.HashableDict(x="b", y="a")),
+            ]
+        )
+        s.commit()
+
+        aa, ab, ba, bb = s.query(Data).order_by(Data.info).all()
+
+        counter = itertools.count()
+        ab.int_value = bindparam(key=None, callable_=lambda: next(counter))
+        ba.int_value = bindparam(key=None, callable_=lambda: next(counter))
+        bb.int_value = bindparam(key=None, callable_=lambda: next(counter))
+        aa.int_value = bindparam(key=None, callable_=lambda: next(counter))
+
+        s.commit()
+
+        eq_(
+            s.query(Data.int_value).order_by(Data.info).all(),
+            [(0,), (1,), (2,), (3,)],
+        )
+
+
 class JoinedInheritancePKOnFKTest(fixtures.MappedTest):
     """Test cascades of pk->non-pk/fk on joined table inh."""
 
index 13c5907a4f03b6ee0519ed8188c73e94dd6d02cb..6185c4a51c4c074a4f55411b57cf9693c596351c 100644 (file)
@@ -15,12 +15,14 @@ from sqlalchemy import literal_column
 from sqlalchemy import select
 from sqlalchemy import String
 from sqlalchemy import testing
+from sqlalchemy.inspection import inspect
 from sqlalchemy.orm import column_property
 from sqlalchemy.orm import create_session
 from sqlalchemy.orm import exc as orm_exc
 from sqlalchemy.orm import mapper
 from sqlalchemy.orm import relationship
 from sqlalchemy.orm import Session
+from sqlalchemy.orm.persistence import _sort_states
 from sqlalchemy.testing import assert_raises
 from sqlalchemy.testing import assert_raises_message
 from sqlalchemy.testing import eq_
@@ -3398,6 +3400,7 @@ class EnsurePKSortableTest(fixtures.MappedTest):
     two = MySortableEnum("two", 2)
     three = MyNotSortableEnum("three", 3)
     four = MyNotSortableEnum("four", 4)
+    five = MyNotSortableEnum("five", 5)
 
     @classmethod
     def define_tables(cls, metadata):
@@ -3411,10 +3414,25 @@ class EnsurePKSortableTest(fixtures.MappedTest):
         Table(
             "t2",
             metadata,
-            Column("id", Enum(cls.MyNotSortableEnum), primary_key=True),
+            Column(
+                "id",
+                Enum(cls.MyNotSortableEnum, sort_key_function=None),
+                primary_key=True,
+            ),
             Column("data", String(10)),
         )
 
+        Table(
+            "t3",
+            metadata,
+            Column("id", Enum(cls.MyNotSortableEnum), primary_key=True),
+            Column("value", Integer),
+        )
+
+    @staticmethod
+    def sort_enum_key_value(value):
+        return value.value
+
     @classmethod
     def setup_classes(cls):
         class T1(cls.Basic):
@@ -3423,10 +3441,15 @@ class EnsurePKSortableTest(fixtures.MappedTest):
         class T2(cls.Basic):
             pass
 
+        class T3(cls.Basic):
+            def __str__(self):
+                return "T3(id={})".format(self.id)
+
     @classmethod
     def setup_mappers(cls):
         mapper(cls.classes.T1, cls.tables.t1)
         mapper(cls.classes.T2, cls.tables.t2)
+        mapper(cls.classes.T3, cls.tables.t3)
 
     def test_exception_persistent_flush_py3k(self):
         s = Session()
@@ -3459,3 +3482,21 @@ class EnsurePKSortableTest(fixtures.MappedTest):
         a.data = "bar"
         b.data = "foo"
         s.commit()
+
+    def test_pep435_custom_sort_key(self):
+        s = Session()
+
+        a = self.classes.T3(id=self.three, value=1)
+        b = self.classes.T3(id=self.four, value=2)
+        s.add_all([a, b])
+        s.commit()
+
+        c = self.classes.T3(id=self.five, value=0)
+        s.add(c)
+
+        states = [o._sa_instance_state for o in [b, a, c]]
+        eq_(
+            _sort_states(inspect(self.classes.T3), states),
+            # pending come first, then "four" < "three"
+            [o._sa_instance_state for o in [c, b, a]],
+        )
index a5c9313f80d2bda2b9beabfae8b424c642d93ef0..e3d2134b7a0d3b0221f5323165ed77afea8cd15a 100644 (file)
@@ -1658,6 +1658,45 @@ class EnumTest(AssertsCompiledSQL, fixtures.TablesTest):
             [(1, "two"), (2, "two"), (3, "one")],
         )
 
+    def test_pep435_default_sort_key(self):
+        one, two, a_member, b_member = (
+            self.one,
+            self.two,
+            self.a_member,
+            self.b_member,
+        )
+        typ = Enum(self.SomeEnum)
+
+        is_(typ.sort_key_function.__func__, typ._db_value_for_elem.__func__)
+
+        eq_(
+            sorted([two, one, a_member, b_member], key=typ.sort_key_function),
+            [a_member, b_member, one, two],
+        )
+
+    def test_pep435_custom_sort_key(self):
+        one, two, a_member, b_member = (
+            self.one,
+            self.two,
+            self.a_member,
+            self.b_member,
+        )
+
+        def sort_enum_key_value(value):
+            return str(value.value)
+
+        typ = Enum(self.SomeEnum, sort_key_function=sort_enum_key_value)
+        is_(typ.sort_key_function, sort_enum_key_value)
+
+        eq_(
+            sorted([two, one, a_member, b_member], key=typ.sort_key_function),
+            [one, two, a_member, b_member],
+        )
+
+    def test_pep435_no_sort_key(self):
+        typ = Enum(self.SomeEnum, sort_key_function=None)
+        is_(typ.sort_key_function, None)
+
     def test_pep435_enum_round_trip(self):
         stdlib_enum_table = self.tables["stdlib_enum_table"]