--- /dev/null
+.. change::
+ :tags: usecase, orm
+ :tickets: 4285
+
+ Added support for the use of an :class:`.Enum` datatype using Python
+ pep-435 enumeration objects as values for use as a primary key column
+ mapped by the ORM. As these values are not inherently sortable, as
+ required by the ORM for primary keys, a new
+ :attr:`.TypeEngine.sort_key_function` attribute is added to the typing
+ system which allows any SQL type to implement a sorting for Python objects
+ of its type which is consulted by the unit of work. The :class:`.Enum`
+ type then defines this using the database value of a given enumeration.
+ The sorting scheme can be also be redefined by passing a callable to the
+ :paramref:`.Enum.sort_key_function` parameter. Pull request courtesy
+ Nicolas Caniart.
)
return identity_key[1]
+ @_memoized_configured_property
+ def _persistent_sortkey_fn(self):
+ key_fns = [col.type.sort_key_function for col in self.primary_key]
+
+ if set(key_fns).difference([None]):
+
+ def key(state):
+ return tuple(
+ key_fn(val) if key_fn is not None else val
+ for key_fn, val in zip(key_fns, state.key[1])
+ )
+
+ else:
+
+ def key(state):
+ return state.key[1]
+
+ return key
+
@_memoized_configured_property
def _identity_key_props(self):
return [self._columntoproperty[col] for col in self.primary_key]
# if batch=false, call _save_obj separately for each object
if not single and not base_mapper.batch:
- for state in _sort_states(states):
+ for state in _sort_states(base_mapper, states):
save_obj(base_mapper, [state], uowtransaction, single=True)
return
connection = uowtransaction.transaction.connection(base_mapper)
connection_callable = None
- for state in _sort_states(states):
+ for state in _sort_states(base_mapper, states):
if connection_callable:
connection = connection_callable(base_mapper, state.obj())
)
-def _sort_states(states):
+def _sort_states(mapper, states):
pending = set(states)
persistent = set(s for s in pending if s.key is not None)
pending.difference_update(persistent)
+
try:
- persistent_sorted = sorted(persistent, key=lambda q: q.key[1])
+ persistent_sorted = sorted(
+ persistent, key=mapper._persistent_sortkey_fn
+ )
except TypeError as err:
raise sa_exc.InvalidRequestError(
"Could not sort objects by primary key; primary key "
from . import operators
from . import type_api
from .base import _bind_or_error
+from .base import NO_ARG
from .base import SchemaEventTarget
from .elements import _defer_name
from .elements import _literal_as_binds
.. versionadded:: 1.2.3
+ :param sort_key_function: a Python callable which may be used as the
+ "key" argument in the Python ``sorted()`` built-in. The SQLAlchemy
+ ORM requires that primary key columns which are mapped must
+ be sortable in some way. When using an unsortable enumeration
+ object such as a Python 3 ``Enum`` object, this parameter may be
+ used to set a default sort key function for the objects. By
+ default, the database value of the enumeration is used as the
+ sorting function.
+
+ .. versionadded:: 1.3.8
+
+
+
"""
self._enum_init(enums, kw)
self.native_enum = kw.pop("native_enum", True)
self.create_constraint = kw.pop("create_constraint", True)
self.values_callable = kw.pop("values_callable", None)
+ self._sort_key_function = kw.pop("sort_key_function", NO_ARG)
values, objects = self._parse_into_values(enums, kw)
self._setup_for_values(values, objects, kw)
]
)
+ @property
+ def sort_key_function(self):
+ if self._sort_key_function is NO_ARG:
+ return self._db_value_for_elem
+ else:
+ return self._sort_key_function
+
@property
def native(self):
return self.native_enum
"""
+ sort_key_function = None
+ """A sorting function that can be passed as the key to sorted.
+
+ The default value of ``None`` indicates that the values stored by
+ this type are self-sorting.
+
+ .. versionadded:: 1.3.8
+
+ """
+
should_evaluate_none = False
"""If True, the Python constant ``None`` is considered to be handled
explicitly by this type.
"""
return self.impl.compare_values(x, y)
+ @property
+ def sort_key_function(self):
+ return self.impl.sort_key_function
+
def __repr__(self):
return util.generic_repr(self, to_inspect=self.impl)
states[4].insert_order = DontCompareMeToString(1)
states[2].insert_order = DontCompareMeToString(3)
eq_(
- _sort_states(states),
+ _sort_states(m, states),
[states[4], states[3], states[0], states[1], states[2]],
)
"""
+import itertools
+
import sqlalchemy as sa
+from sqlalchemy import bindparam
from sqlalchemy import ForeignKey
from sqlalchemy import Integer
from sqlalchemy import String
from sqlalchemy import testing
+from sqlalchemy import TypeDecorator
from sqlalchemy.orm import create_session
from sqlalchemy.orm import mapper
from sqlalchemy.orm import relationship
)
+class UnsortablePKTest(fixtures.MappedTest):
+ """Test integration with TypeEngine.sort_key_function"""
+
+ class HashableDict(dict):
+ def __hash__(self):
+ return hash((self["x"], self["y"]))
+
+ @classmethod
+ def define_tables(cls, metadata):
+ class MyUnsortable(TypeDecorator):
+ impl = String(10)
+
+ def process_bind_param(self, value, dialect):
+ return "%s,%s" % (value["x"], value["y"])
+
+ def process_result_value(self, value, dialect):
+ rec = value.split(",")
+ return cls.HashableDict({"x": rec[0], "y": rec[1]})
+
+ def sort_key_function(self, value):
+ return (value["x"], value["y"])
+
+ Table(
+ "data",
+ metadata,
+ Column("info", MyUnsortable(), primary_key=True),
+ Column("int_value", Integer),
+ )
+
+ @classmethod
+ def setup_classes(cls):
+ class Data(cls.Comparable):
+ pass
+
+ @classmethod
+ def setup_mappers(cls):
+ mapper(cls.classes.Data, cls.tables.data)
+
+ def test_updates_sorted(self):
+ Data = self.classes.Data
+ s = Session()
+
+ s.add_all(
+ [
+ Data(info=self.HashableDict(x="a", y="b")),
+ Data(info=self.HashableDict(x="a", y="a")),
+ Data(info=self.HashableDict(x="b", y="b")),
+ Data(info=self.HashableDict(x="b", y="a")),
+ ]
+ )
+ s.commit()
+
+ aa, ab, ba, bb = s.query(Data).order_by(Data.info).all()
+
+ counter = itertools.count()
+ ab.int_value = bindparam(key=None, callable_=lambda: next(counter))
+ ba.int_value = bindparam(key=None, callable_=lambda: next(counter))
+ bb.int_value = bindparam(key=None, callable_=lambda: next(counter))
+ aa.int_value = bindparam(key=None, callable_=lambda: next(counter))
+
+ s.commit()
+
+ eq_(
+ s.query(Data.int_value).order_by(Data.info).all(),
+ [(0,), (1,), (2,), (3,)],
+ )
+
+
class JoinedInheritancePKOnFKTest(fixtures.MappedTest):
"""Test cascades of pk->non-pk/fk on joined table inh."""
from sqlalchemy import select
from sqlalchemy import String
from sqlalchemy import testing
+from sqlalchemy.inspection import inspect
from sqlalchemy.orm import column_property
from sqlalchemy.orm import create_session
from sqlalchemy.orm import exc as orm_exc
from sqlalchemy.orm import mapper
from sqlalchemy.orm import relationship
from sqlalchemy.orm import Session
+from sqlalchemy.orm.persistence import _sort_states
from sqlalchemy.testing import assert_raises
from sqlalchemy.testing import assert_raises_message
from sqlalchemy.testing import eq_
two = MySortableEnum("two", 2)
three = MyNotSortableEnum("three", 3)
four = MyNotSortableEnum("four", 4)
+ five = MyNotSortableEnum("five", 5)
@classmethod
def define_tables(cls, metadata):
Table(
"t2",
metadata,
- Column("id", Enum(cls.MyNotSortableEnum), primary_key=True),
+ Column(
+ "id",
+ Enum(cls.MyNotSortableEnum, sort_key_function=None),
+ primary_key=True,
+ ),
Column("data", String(10)),
)
+ Table(
+ "t3",
+ metadata,
+ Column("id", Enum(cls.MyNotSortableEnum), primary_key=True),
+ Column("value", Integer),
+ )
+
+ @staticmethod
+ def sort_enum_key_value(value):
+ return value.value
+
@classmethod
def setup_classes(cls):
class T1(cls.Basic):
class T2(cls.Basic):
pass
+ class T3(cls.Basic):
+ def __str__(self):
+ return "T3(id={})".format(self.id)
+
@classmethod
def setup_mappers(cls):
mapper(cls.classes.T1, cls.tables.t1)
mapper(cls.classes.T2, cls.tables.t2)
+ mapper(cls.classes.T3, cls.tables.t3)
def test_exception_persistent_flush_py3k(self):
s = Session()
a.data = "bar"
b.data = "foo"
s.commit()
+
+ def test_pep435_custom_sort_key(self):
+ s = Session()
+
+ a = self.classes.T3(id=self.three, value=1)
+ b = self.classes.T3(id=self.four, value=2)
+ s.add_all([a, b])
+ s.commit()
+
+ c = self.classes.T3(id=self.five, value=0)
+ s.add(c)
+
+ states = [o._sa_instance_state for o in [b, a, c]]
+ eq_(
+ _sort_states(inspect(self.classes.T3), states),
+ # pending come first, then "four" < "three"
+ [o._sa_instance_state for o in [c, b, a]],
+ )
[(1, "two"), (2, "two"), (3, "one")],
)
+ def test_pep435_default_sort_key(self):
+ one, two, a_member, b_member = (
+ self.one,
+ self.two,
+ self.a_member,
+ self.b_member,
+ )
+ typ = Enum(self.SomeEnum)
+
+ is_(typ.sort_key_function.__func__, typ._db_value_for_elem.__func__)
+
+ eq_(
+ sorted([two, one, a_member, b_member], key=typ.sort_key_function),
+ [a_member, b_member, one, two],
+ )
+
+ def test_pep435_custom_sort_key(self):
+ one, two, a_member, b_member = (
+ self.one,
+ self.two,
+ self.a_member,
+ self.b_member,
+ )
+
+ def sort_enum_key_value(value):
+ return str(value.value)
+
+ typ = Enum(self.SomeEnum, sort_key_function=sort_enum_key_value)
+ is_(typ.sort_key_function, sort_enum_key_value)
+
+ eq_(
+ sorted([two, one, a_member, b_member], key=typ.sort_key_function),
+ [one, two, a_member, b_member],
+ )
+
+ def test_pep435_no_sort_key(self):
+ typ = Enum(self.SomeEnum, sort_key_function=None)
+ is_(typ.sort_key_function, None)
+
def test_pep435_enum_round_trip(self):
stdlib_enum_table = self.tables["stdlib_enum_table"]