]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Implement type-level sorting for Enum; apply to ORM primary keys
authorNicolas CANIART <nicolas@caniart.net>
Thu, 22 Aug 2019 18:16:29 +0000 (14:16 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 27 Aug 2019 16:59:32 +0000 (12:59 -0400)
Added support for the use of an :class:`.Enum` datatype using Python
pep-435 enumeration objects as values for use as a primary key column
mapped by the ORM.  As these values are not inherently sortable, as
required by the ORM for primary keys, a new
:attr:`.TypeEngine.sort_key_function` attribute is added to the typing
system which allows any SQL type to  implement a sorting for Python objects
of its type which is consulted by the unit of work.   The :class:`.Enum`
type then defines this using the  database value of a given enumeration.
The sorting scheme can be  also be redefined by passing a callable to the
:paramref:`.Enum.sort_key_function` parameter.  Pull request courtesy
Nicolas Caniart.

Fixes: #4285
Closes: #4816
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/4816
Pull-request-sha: 42266b766c1e462d5b8a409cda05d33dea13bd34

Change-Id: Iadcc16173c1ba26ffac5830db57743a4cb987c55
(cherry picked from commit 75b2518b2659796c885396fd0893dd7f9b19a9ef)

doc/build/changelog/unreleased_13/4285.rst [new file with mode: 0644]
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/persistence.py
lib/sqlalchemy/sql/sqltypes.py
lib/sqlalchemy/sql/type_api.py
test/orm/test_mapper.py
test/orm/test_naturalpks.py
test/orm/test_unitofwork.py
test/sql/test_types.py

diff --git a/doc/build/changelog/unreleased_13/4285.rst b/doc/build/changelog/unreleased_13/4285.rst
new file mode 100644 (file)
index 0000000..1049a58
--- /dev/null
@@ -0,0 +1,15 @@
+.. change::
+    :tags: usecase, orm
+    :tickets: 4285
+
+    Added support for the use of an :class:`.Enum` datatype using Python
+    pep-435 enumeration objects as values for use as a primary key column
+    mapped by the ORM.  As these values are not inherently sortable, as
+    required by the ORM for primary keys, a new
+    :attr:`.TypeEngine.sort_key_function` attribute is added to the typing
+    system which allows any SQL type to  implement a sorting for Python objects
+    of its type which is consulted by the unit of work.   The :class:`.Enum`
+    type then defines this using the  database value of a given enumeration.
+    The sorting scheme can be  also be redefined by passing a callable to the
+    :paramref:`.Enum.sort_key_function` parameter.  Pull request courtesy
+    Nicolas Caniart.
index 6eadffb16030d36efd2f09e95a1e23b8a6cf6785..3d62c8b4aeb85435e3782429d259b74254392a2a 100644 (file)
@@ -2747,6 +2747,25 @@ class Mapper(InspectionAttr):
         )
         return identity_key[1]
 
+    @_memoized_configured_property
+    def _persistent_sortkey_fn(self):
+        key_fns = [col.type.sort_key_function for col in self.primary_key]
+
+        if set(key_fns).difference([None]):
+
+            def key(state):
+                return tuple(
+                    key_fn(val) if key_fn is not None else val
+                    for key_fn, val in zip(key_fns, state.key[1])
+                )
+
+        else:
+
+            def key(state):
+                return state.key[1]
+
+        return key
+
     @_memoized_configured_property
     def _identity_key_props(self):
         return [self._columntoproperty[col] for col in self.primary_key]
index e837e46001b25051f150d3970e34a8cf630cf187..38171966a6eb6ca08d2bc2561df35679fe7d4f62 100644 (file)
@@ -194,7 +194,7 @@ def save_obj(base_mapper, states, uowtransaction, single=False):
 
     # if batch=false, call _save_obj separately for each object
     if not single and not base_mapper.batch:
-        for state in _sort_states(states):
+        for state in _sort_states(base_mapper, states):
             save_obj(base_mapper, [state], uowtransaction, single=True)
         return
 
@@ -1605,7 +1605,7 @@ def _connections_for_states(base_mapper, uowtransaction, states):
         connection = uowtransaction.transaction.connection(base_mapper)
         connection_callable = None
 
-    for state in _sort_states(states):
+    for state in _sort_states(base_mapper, states):
         if connection_callable:
             connection = connection_callable(base_mapper, state.obj())
 
@@ -1623,12 +1623,15 @@ def _cached_connection_dict(base_mapper):
     )
 
 
-def _sort_states(states):
+def _sort_states(mapper, states):
     pending = set(states)
     persistent = set(s for s in pending if s.key is not None)
     pending.difference_update(persistent)
+
     try:
-        persistent_sorted = sorted(persistent, key=lambda q: q.key[1])
+        persistent_sorted = sorted(
+            persistent, key=mapper._persistent_sortkey_fn
+        )
     except TypeError as err:
         raise sa_exc.InvalidRequestError(
             "Could not sort objects by primary key; primary key "
index aeac1587c0ca2c5f4d490d6a4891b0c100ff4c35..0c32a6d51b4ebbe55ea9c3fc795c8ccbf451a7e3 100644 (file)
@@ -18,6 +18,7 @@ from . import elements
 from . import operators
 from . import type_api
 from .base import _bind_or_error
+from .base import NO_ARG
 from .base import SchemaEventTarget
 from .elements import _defer_name
 from .elements import _literal_as_binds
@@ -1355,6 +1356,19 @@ class Enum(Emulated, String, SchemaType):
 
            .. versionadded:: 1.2.3
 
+        :param sort_key_function: a Python callable which may be used as the
+           "key" argument in the Python ``sorted()`` built-in.   The SQLAlchemy
+           ORM requires that primary key columns which are mapped must
+           be sortable in some way.  When using an unsortable enumeration
+           object such as a Python 3 ``Enum`` object, this parameter may be
+           used to set a default sort key function for the objects.  By
+           default, the database value of the enumeration is used as the
+           sorting function.
+
+            .. versionadded:: 1.3.8
+
+
+
         """
         self._enum_init(enums, kw)
 
@@ -1376,6 +1390,7 @@ class Enum(Emulated, String, SchemaType):
         self.native_enum = kw.pop("native_enum", True)
         self.create_constraint = kw.pop("create_constraint", True)
         self.values_callable = kw.pop("values_callable", None)
+        self._sort_key_function = kw.pop("sort_key_function", NO_ARG)
 
         values, objects = self._parse_into_values(enums, kw)
         self._setup_for_values(values, objects, kw)
@@ -1448,6 +1463,13 @@ class Enum(Emulated, String, SchemaType):
             ]
         )
 
+    @property
+    def sort_key_function(self):
+        if self._sort_key_function is NO_ARG:
+            return self._db_value_for_elem
+        else:
+            return self._sort_key_function
+
     @property
     def native(self):
         return self.native_enum
index 5f96b1aad8a41ccdb78a4d4b481c8620e3b03303..fa61abdf174091c59a68f261d8d3a7a8bc61d564 100644 (file)
@@ -132,6 +132,16 @@ class TypeEngine(Visitable):
 
     """
 
+    sort_key_function = None
+    """A sorting function that can be passed as the key to sorted.
+
+    The default value of ``None`` indicates that the values stored by
+    this type are self-sorting.
+
+    .. versionadded:: 1.3.8
+
+    """
+
     should_evaluate_none = False
     """If True, the Python constant ``None`` is considered to be handled
     explicitly by this type.
@@ -1347,6 +1357,10 @@ class TypeDecorator(SchemaEventTarget, TypeEngine):
         """
         return self.impl.compare_values(x, y)
 
+    @property
+    def sort_key_function(self):
+        return self.impl.sort_key_function
+
     def __repr__(self):
         return util.generic_repr(self, to_inspect=self.impl)
 
index 08f3eca2b8b88496b5e02c8f28d4be69d53f9ee5..72e428efcd120797a8fd2d254fc412f494ea63d6 100644 (file)
@@ -346,7 +346,7 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL):
         states[4].insert_order = DontCompareMeToString(1)
         states[2].insert_order = DontCompareMeToString(3)
         eq_(
-            _sort_states(states),
+            _sort_states(m, states),
             [states[4], states[3], states[0], states[1], states[2]],
         )
 
index 6108a28c4cab26e0685a4da1bf03dd964939ecd9..9a25a618dc9de5ef395535300b800a90c4875d5f 100644 (file)
@@ -3,11 +3,15 @@ Primary key changing capabilities and passive/non-passive cascading updates.
 
 """
 
+import itertools
+
 import sqlalchemy as sa
+from sqlalchemy import bindparam
 from sqlalchemy import ForeignKey
 from sqlalchemy import Integer
 from sqlalchemy import String
 from sqlalchemy import testing
+from sqlalchemy import TypeDecorator
 from sqlalchemy.orm import create_session
 from sqlalchemy.orm import mapper
 from sqlalchemy.orm import relationship
@@ -1754,6 +1758,74 @@ class JoinedInheritanceTest(fixtures.MappedTest):
         )
 
 
+class UnsortablePKTest(fixtures.MappedTest):
+    """Test integration with TypeEngine.sort_key_function"""
+
+    class HashableDict(dict):
+        def __hash__(self):
+            return hash((self["x"], self["y"]))
+
+    @classmethod
+    def define_tables(cls, metadata):
+        class MyUnsortable(TypeDecorator):
+            impl = String(10)
+
+            def process_bind_param(self, value, dialect):
+                return "%s,%s" % (value["x"], value["y"])
+
+            def process_result_value(self, value, dialect):
+                rec = value.split(",")
+                return cls.HashableDict({"x": rec[0], "y": rec[1]})
+
+            def sort_key_function(self, value):
+                return (value["x"], value["y"])
+
+        Table(
+            "data",
+            metadata,
+            Column("info", MyUnsortable(), primary_key=True),
+            Column("int_value", Integer),
+        )
+
+    @classmethod
+    def setup_classes(cls):
+        class Data(cls.Comparable):
+            pass
+
+    @classmethod
+    def setup_mappers(cls):
+        mapper(cls.classes.Data, cls.tables.data)
+
+    def test_updates_sorted(self):
+        Data = self.classes.Data
+        s = Session()
+
+        s.add_all(
+            [
+                Data(info=self.HashableDict(x="a", y="b")),
+                Data(info=self.HashableDict(x="a", y="a")),
+                Data(info=self.HashableDict(x="b", y="b")),
+                Data(info=self.HashableDict(x="b", y="a")),
+            ]
+        )
+        s.commit()
+
+        aa, ab, ba, bb = s.query(Data).order_by(Data.info).all()
+
+        counter = itertools.count()
+        ab.int_value = bindparam(key=None, callable_=lambda: next(counter))
+        ba.int_value = bindparam(key=None, callable_=lambda: next(counter))
+        bb.int_value = bindparam(key=None, callable_=lambda: next(counter))
+        aa.int_value = bindparam(key=None, callable_=lambda: next(counter))
+
+        s.commit()
+
+        eq_(
+            s.query(Data.int_value).order_by(Data.info).all(),
+            [(0,), (1,), (2,), (3,)],
+        )
+
+
 class JoinedInheritancePKOnFKTest(fixtures.MappedTest):
     """Test cascades of pk->non-pk/fk on joined table inh."""
 
index 13c5907a4f03b6ee0519ed8188c73e94dd6d02cb..6185c4a51c4c074a4f55411b57cf9693c596351c 100644 (file)
@@ -15,12 +15,14 @@ from sqlalchemy import literal_column
 from sqlalchemy import select
 from sqlalchemy import String
 from sqlalchemy import testing
+from sqlalchemy.inspection import inspect
 from sqlalchemy.orm import column_property
 from sqlalchemy.orm import create_session
 from sqlalchemy.orm import exc as orm_exc
 from sqlalchemy.orm import mapper
 from sqlalchemy.orm import relationship
 from sqlalchemy.orm import Session
+from sqlalchemy.orm.persistence import _sort_states
 from sqlalchemy.testing import assert_raises
 from sqlalchemy.testing import assert_raises_message
 from sqlalchemy.testing import eq_
@@ -3398,6 +3400,7 @@ class EnsurePKSortableTest(fixtures.MappedTest):
     two = MySortableEnum("two", 2)
     three = MyNotSortableEnum("three", 3)
     four = MyNotSortableEnum("four", 4)
+    five = MyNotSortableEnum("five", 5)
 
     @classmethod
     def define_tables(cls, metadata):
@@ -3411,10 +3414,25 @@ class EnsurePKSortableTest(fixtures.MappedTest):
         Table(
             "t2",
             metadata,
-            Column("id", Enum(cls.MyNotSortableEnum), primary_key=True),
+            Column(
+                "id",
+                Enum(cls.MyNotSortableEnum, sort_key_function=None),
+                primary_key=True,
+            ),
             Column("data", String(10)),
         )
 
+        Table(
+            "t3",
+            metadata,
+            Column("id", Enum(cls.MyNotSortableEnum), primary_key=True),
+            Column("value", Integer),
+        )
+
+    @staticmethod
+    def sort_enum_key_value(value):
+        return value.value
+
     @classmethod
     def setup_classes(cls):
         class T1(cls.Basic):
@@ -3423,10 +3441,15 @@ class EnsurePKSortableTest(fixtures.MappedTest):
         class T2(cls.Basic):
             pass
 
+        class T3(cls.Basic):
+            def __str__(self):
+                return "T3(id={})".format(self.id)
+
     @classmethod
     def setup_mappers(cls):
         mapper(cls.classes.T1, cls.tables.t1)
         mapper(cls.classes.T2, cls.tables.t2)
+        mapper(cls.classes.T3, cls.tables.t3)
 
     def test_exception_persistent_flush_py3k(self):
         s = Session()
@@ -3459,3 +3482,21 @@ class EnsurePKSortableTest(fixtures.MappedTest):
         a.data = "bar"
         b.data = "foo"
         s.commit()
+
+    def test_pep435_custom_sort_key(self):
+        s = Session()
+
+        a = self.classes.T3(id=self.three, value=1)
+        b = self.classes.T3(id=self.four, value=2)
+        s.add_all([a, b])
+        s.commit()
+
+        c = self.classes.T3(id=self.five, value=0)
+        s.add(c)
+
+        states = [o._sa_instance_state for o in [b, a, c]]
+        eq_(
+            _sort_states(inspect(self.classes.T3), states),
+            # pending come first, then "four" < "three"
+            [o._sa_instance_state for o in [c, b, a]],
+        )
index a1b1f024b91d7a66e72dd56b8be70c4617aabfa6..491c1cf7f55c6f8e2189aa037dac1cc9af9322d8 100644 (file)
@@ -1658,6 +1658,45 @@ class EnumTest(AssertsCompiledSQL, fixtures.TablesTest):
             [(1, "two"), (2, "two"), (3, "one")],
         )
 
+    def test_pep435_default_sort_key(self):
+        one, two, a_member, b_member = (
+            self.one,
+            self.two,
+            self.a_member,
+            self.b_member,
+        )
+        typ = Enum(self.SomeEnum)
+
+        is_(typ.sort_key_function.__func__, typ._db_value_for_elem.__func__)
+
+        eq_(
+            sorted([two, one, a_member, b_member], key=typ.sort_key_function),
+            [a_member, b_member, one, two],
+        )
+
+    def test_pep435_custom_sort_key(self):
+        one, two, a_member, b_member = (
+            self.one,
+            self.two,
+            self.a_member,
+            self.b_member,
+        )
+
+        def sort_enum_key_value(value):
+            return str(value.value)
+
+        typ = Enum(self.SomeEnum, sort_key_function=sort_enum_key_value)
+        is_(typ.sort_key_function, sort_enum_key_value)
+
+        eq_(
+            sorted([two, one, a_member, b_member], key=typ.sort_key_function),
+            [one, two, a_member, b_member],
+        )
+
+    def test_pep435_no_sort_key(self):
+        typ = Enum(self.SomeEnum, sort_key_function=None)
+        is_(typ.sort_key_function, None)
+
     def test_pep435_enum_round_trip(self):
         stdlib_enum_table = self.tables["stdlib_enum_table"]