From: Nicolas CANIART Date: Thu, 22 Aug 2019 18:16:29 +0000 (-0400) Subject: Implement type-level sorting for Enum; apply to ORM primary keys X-Git-Tag: rel_1_3_8~1^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=a35740b509904fd3ac845aeac50da5f80e6d14e6;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Implement type-level sorting for Enum; apply to ORM primary keys Added support for the use of an :class:`.Enum` datatype using Python pep-435 enumeration objects as values for use as a primary key column mapped by the ORM. As these values are not inherently sortable, as required by the ORM for primary keys, a new :attr:`.TypeEngine.sort_key_function` attribute is added to the typing system which allows any SQL type to implement a sorting for Python objects of its type which is consulted by the unit of work. The :class:`.Enum` type then defines this using the database value of a given enumeration. The sorting scheme can be also be redefined by passing a callable to the :paramref:`.Enum.sort_key_function` parameter. Pull request courtesy Nicolas Caniart. Fixes: #4285 Closes: #4816 Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/4816 Pull-request-sha: 42266b766c1e462d5b8a409cda05d33dea13bd34 Change-Id: Iadcc16173c1ba26ffac5830db57743a4cb987c55 (cherry picked from commit 75b2518b2659796c885396fd0893dd7f9b19a9ef) --- diff --git a/doc/build/changelog/unreleased_13/4285.rst b/doc/build/changelog/unreleased_13/4285.rst new file mode 100644 index 0000000000..1049a5882b --- /dev/null +++ b/doc/build/changelog/unreleased_13/4285.rst @@ -0,0 +1,15 @@ +.. change:: + :tags: usecase, orm + :tickets: 4285 + + Added support for the use of an :class:`.Enum` datatype using Python + pep-435 enumeration objects as values for use as a primary key column + mapped by the ORM. As these values are not inherently sortable, as + required by the ORM for primary keys, a new + :attr:`.TypeEngine.sort_key_function` attribute is added to the typing + system which allows any SQL type to implement a sorting for Python objects + of its type which is consulted by the unit of work. The :class:`.Enum` + type then defines this using the database value of a given enumeration. + The sorting scheme can be also be redefined by passing a callable to the + :paramref:`.Enum.sort_key_function` parameter. Pull request courtesy + Nicolas Caniart. diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 6eadffb160..3d62c8b4ae 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -2747,6 +2747,25 @@ class Mapper(InspectionAttr): ) return identity_key[1] + @_memoized_configured_property + def _persistent_sortkey_fn(self): + key_fns = [col.type.sort_key_function for col in self.primary_key] + + if set(key_fns).difference([None]): + + def key(state): + return tuple( + key_fn(val) if key_fn is not None else val + for key_fn, val in zip(key_fns, state.key[1]) + ) + + else: + + def key(state): + return state.key[1] + + return key + @_memoized_configured_property def _identity_key_props(self): return [self._columntoproperty[col] for col in self.primary_key] diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py index e837e46001..38171966a6 100644 --- a/lib/sqlalchemy/orm/persistence.py +++ b/lib/sqlalchemy/orm/persistence.py @@ -194,7 +194,7 @@ def save_obj(base_mapper, states, uowtransaction, single=False): # if batch=false, call _save_obj separately for each object if not single and not base_mapper.batch: - for state in _sort_states(states): + for state in _sort_states(base_mapper, states): save_obj(base_mapper, [state], uowtransaction, single=True) return @@ -1605,7 +1605,7 @@ def _connections_for_states(base_mapper, uowtransaction, states): connection = uowtransaction.transaction.connection(base_mapper) connection_callable = None - for state in _sort_states(states): + for state in _sort_states(base_mapper, states): if connection_callable: connection = connection_callable(base_mapper, state.obj()) @@ -1623,12 +1623,15 @@ def _cached_connection_dict(base_mapper): ) -def _sort_states(states): +def _sort_states(mapper, states): pending = set(states) persistent = set(s for s in pending if s.key is not None) pending.difference_update(persistent) + try: - persistent_sorted = sorted(persistent, key=lambda q: q.key[1]) + persistent_sorted = sorted( + persistent, key=mapper._persistent_sortkey_fn + ) except TypeError as err: raise sa_exc.InvalidRequestError( "Could not sort objects by primary key; primary key " diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py index aeac1587c0..0c32a6d51b 100644 --- a/lib/sqlalchemy/sql/sqltypes.py +++ b/lib/sqlalchemy/sql/sqltypes.py @@ -18,6 +18,7 @@ from . import elements from . import operators from . import type_api from .base import _bind_or_error +from .base import NO_ARG from .base import SchemaEventTarget from .elements import _defer_name from .elements import _literal_as_binds @@ -1355,6 +1356,19 @@ class Enum(Emulated, String, SchemaType): .. versionadded:: 1.2.3 + :param sort_key_function: a Python callable which may be used as the + "key" argument in the Python ``sorted()`` built-in. The SQLAlchemy + ORM requires that primary key columns which are mapped must + be sortable in some way. When using an unsortable enumeration + object such as a Python 3 ``Enum`` object, this parameter may be + used to set a default sort key function for the objects. By + default, the database value of the enumeration is used as the + sorting function. + + .. versionadded:: 1.3.8 + + + """ self._enum_init(enums, kw) @@ -1376,6 +1390,7 @@ class Enum(Emulated, String, SchemaType): self.native_enum = kw.pop("native_enum", True) self.create_constraint = kw.pop("create_constraint", True) self.values_callable = kw.pop("values_callable", None) + self._sort_key_function = kw.pop("sort_key_function", NO_ARG) values, objects = self._parse_into_values(enums, kw) self._setup_for_values(values, objects, kw) @@ -1448,6 +1463,13 @@ class Enum(Emulated, String, SchemaType): ] ) + @property + def sort_key_function(self): + if self._sort_key_function is NO_ARG: + return self._db_value_for_elem + else: + return self._sort_key_function + @property def native(self): return self.native_enum diff --git a/lib/sqlalchemy/sql/type_api.py b/lib/sqlalchemy/sql/type_api.py index 5f96b1aad8..fa61abdf17 100644 --- a/lib/sqlalchemy/sql/type_api.py +++ b/lib/sqlalchemy/sql/type_api.py @@ -132,6 +132,16 @@ class TypeEngine(Visitable): """ + sort_key_function = None + """A sorting function that can be passed as the key to sorted. + + The default value of ``None`` indicates that the values stored by + this type are self-sorting. + + .. versionadded:: 1.3.8 + + """ + should_evaluate_none = False """If True, the Python constant ``None`` is considered to be handled explicitly by this type. @@ -1347,6 +1357,10 @@ class TypeDecorator(SchemaEventTarget, TypeEngine): """ return self.impl.compare_values(x, y) + @property + def sort_key_function(self): + return self.impl.sort_key_function + def __repr__(self): return util.generic_repr(self, to_inspect=self.impl) diff --git a/test/orm/test_mapper.py b/test/orm/test_mapper.py index 08f3eca2b8..72e428efcd 100644 --- a/test/orm/test_mapper.py +++ b/test/orm/test_mapper.py @@ -346,7 +346,7 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): states[4].insert_order = DontCompareMeToString(1) states[2].insert_order = DontCompareMeToString(3) eq_( - _sort_states(states), + _sort_states(m, states), [states[4], states[3], states[0], states[1], states[2]], ) diff --git a/test/orm/test_naturalpks.py b/test/orm/test_naturalpks.py index 6108a28c4c..9a25a618dc 100644 --- a/test/orm/test_naturalpks.py +++ b/test/orm/test_naturalpks.py @@ -3,11 +3,15 @@ Primary key changing capabilities and passive/non-passive cascading updates. """ +import itertools + import sqlalchemy as sa +from sqlalchemy import bindparam from sqlalchemy import ForeignKey from sqlalchemy import Integer from sqlalchemy import String from sqlalchemy import testing +from sqlalchemy import TypeDecorator from sqlalchemy.orm import create_session from sqlalchemy.orm import mapper from sqlalchemy.orm import relationship @@ -1754,6 +1758,74 @@ class JoinedInheritanceTest(fixtures.MappedTest): ) +class UnsortablePKTest(fixtures.MappedTest): + """Test integration with TypeEngine.sort_key_function""" + + class HashableDict(dict): + def __hash__(self): + return hash((self["x"], self["y"])) + + @classmethod + def define_tables(cls, metadata): + class MyUnsortable(TypeDecorator): + impl = String(10) + + def process_bind_param(self, value, dialect): + return "%s,%s" % (value["x"], value["y"]) + + def process_result_value(self, value, dialect): + rec = value.split(",") + return cls.HashableDict({"x": rec[0], "y": rec[1]}) + + def sort_key_function(self, value): + return (value["x"], value["y"]) + + Table( + "data", + metadata, + Column("info", MyUnsortable(), primary_key=True), + Column("int_value", Integer), + ) + + @classmethod + def setup_classes(cls): + class Data(cls.Comparable): + pass + + @classmethod + def setup_mappers(cls): + mapper(cls.classes.Data, cls.tables.data) + + def test_updates_sorted(self): + Data = self.classes.Data + s = Session() + + s.add_all( + [ + Data(info=self.HashableDict(x="a", y="b")), + Data(info=self.HashableDict(x="a", y="a")), + Data(info=self.HashableDict(x="b", y="b")), + Data(info=self.HashableDict(x="b", y="a")), + ] + ) + s.commit() + + aa, ab, ba, bb = s.query(Data).order_by(Data.info).all() + + counter = itertools.count() + ab.int_value = bindparam(key=None, callable_=lambda: next(counter)) + ba.int_value = bindparam(key=None, callable_=lambda: next(counter)) + bb.int_value = bindparam(key=None, callable_=lambda: next(counter)) + aa.int_value = bindparam(key=None, callable_=lambda: next(counter)) + + s.commit() + + eq_( + s.query(Data.int_value).order_by(Data.info).all(), + [(0,), (1,), (2,), (3,)], + ) + + class JoinedInheritancePKOnFKTest(fixtures.MappedTest): """Test cascades of pk->non-pk/fk on joined table inh.""" diff --git a/test/orm/test_unitofwork.py b/test/orm/test_unitofwork.py index 13c5907a4f..6185c4a51c 100644 --- a/test/orm/test_unitofwork.py +++ b/test/orm/test_unitofwork.py @@ -15,12 +15,14 @@ from sqlalchemy import literal_column from sqlalchemy import select from sqlalchemy import String from sqlalchemy import testing +from sqlalchemy.inspection import inspect from sqlalchemy.orm import column_property from sqlalchemy.orm import create_session from sqlalchemy.orm import exc as orm_exc from sqlalchemy.orm import mapper from sqlalchemy.orm import relationship from sqlalchemy.orm import Session +from sqlalchemy.orm.persistence import _sort_states from sqlalchemy.testing import assert_raises from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import eq_ @@ -3398,6 +3400,7 @@ class EnsurePKSortableTest(fixtures.MappedTest): two = MySortableEnum("two", 2) three = MyNotSortableEnum("three", 3) four = MyNotSortableEnum("four", 4) + five = MyNotSortableEnum("five", 5) @classmethod def define_tables(cls, metadata): @@ -3411,10 +3414,25 @@ class EnsurePKSortableTest(fixtures.MappedTest): Table( "t2", metadata, - Column("id", Enum(cls.MyNotSortableEnum), primary_key=True), + Column( + "id", + Enum(cls.MyNotSortableEnum, sort_key_function=None), + primary_key=True, + ), Column("data", String(10)), ) + Table( + "t3", + metadata, + Column("id", Enum(cls.MyNotSortableEnum), primary_key=True), + Column("value", Integer), + ) + + @staticmethod + def sort_enum_key_value(value): + return value.value + @classmethod def setup_classes(cls): class T1(cls.Basic): @@ -3423,10 +3441,15 @@ class EnsurePKSortableTest(fixtures.MappedTest): class T2(cls.Basic): pass + class T3(cls.Basic): + def __str__(self): + return "T3(id={})".format(self.id) + @classmethod def setup_mappers(cls): mapper(cls.classes.T1, cls.tables.t1) mapper(cls.classes.T2, cls.tables.t2) + mapper(cls.classes.T3, cls.tables.t3) def test_exception_persistent_flush_py3k(self): s = Session() @@ -3459,3 +3482,21 @@ class EnsurePKSortableTest(fixtures.MappedTest): a.data = "bar" b.data = "foo" s.commit() + + def test_pep435_custom_sort_key(self): + s = Session() + + a = self.classes.T3(id=self.three, value=1) + b = self.classes.T3(id=self.four, value=2) + s.add_all([a, b]) + s.commit() + + c = self.classes.T3(id=self.five, value=0) + s.add(c) + + states = [o._sa_instance_state for o in [b, a, c]] + eq_( + _sort_states(inspect(self.classes.T3), states), + # pending come first, then "four" < "three" + [o._sa_instance_state for o in [c, b, a]], + ) diff --git a/test/sql/test_types.py b/test/sql/test_types.py index a1b1f024b9..491c1cf7f5 100644 --- a/test/sql/test_types.py +++ b/test/sql/test_types.py @@ -1658,6 +1658,45 @@ class EnumTest(AssertsCompiledSQL, fixtures.TablesTest): [(1, "two"), (2, "two"), (3, "one")], ) + def test_pep435_default_sort_key(self): + one, two, a_member, b_member = ( + self.one, + self.two, + self.a_member, + self.b_member, + ) + typ = Enum(self.SomeEnum) + + is_(typ.sort_key_function.__func__, typ._db_value_for_elem.__func__) + + eq_( + sorted([two, one, a_member, b_member], key=typ.sort_key_function), + [a_member, b_member, one, two], + ) + + def test_pep435_custom_sort_key(self): + one, two, a_member, b_member = ( + self.one, + self.two, + self.a_member, + self.b_member, + ) + + def sort_enum_key_value(value): + return str(value.value) + + typ = Enum(self.SomeEnum, sort_key_function=sort_enum_key_value) + is_(typ.sort_key_function, sort_enum_key_value) + + eq_( + sorted([two, one, a_member, b_member], key=typ.sort_key_function), + [one, two, a_member, b_member], + ) + + def test_pep435_no_sort_key(self): + typ = Enum(self.SomeEnum, sort_key_function=None) + is_(typ.sort_key_function, None) + def test_pep435_enum_round_trip(self): stdlib_enum_table = self.tables["stdlib_enum_table"]