From eac7ca356aebe7b2c43b889aff01cd952018ffe6 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Thu, 22 Mar 2007 20:54:52 +0000 Subject: [PATCH] dan's latest patch for session.identity_key() --- lib/sqlalchemy/orm/session.py | 65 ++++++++++++++++++----------------- test/orm/session.py | 34 ++++++++++++++++++ 2 files changed, 68 insertions(+), 31 deletions(-) diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index 72a1c4c85e..51d850d3c1 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -452,54 +452,57 @@ class Session(object): identity_key(class\_, ident, entity_name=None) class\_ - mapped class - + mapped class (must be a positional argument) + ident primary key, if the key is composite this is a tuple entity_name - optional entity name. May be given as a - positional arg or as a keyword arg. + optional entity name identity_key(instance=instance) instance object instance (must be given as a keyword arg) - identity_key(row=row, class=class\_, entity_name=None) + identity_key(class\_, row=row, entity_name=None) + class\_ + mapped class (must be a positional argument) + row result proxy row (must be given as a keyword arg) - + + entity_name + optional entity name (must be given as a keyword arg) """ if args: - kw = {} - if len(args) == 2: + if len(args) == 1: + class_ = args[0] + try: + row = kwargs.pop("row") + except KeyError: + ident = kwargs.pop("ident") + entity_name = kwargs.pop("entity_name", None) + elif len(args) == 2: class_, ident = args entity_name = kwargs.pop("entity_name", None) - assert not kwargs, ("unknown keyword arguments: %s" - % (kwargs.keys(),)) - else: - assert len(args) == 3, ("two or three positional args are " - "accepted, got %s" % len(args)) + elif len(args) == 3: class_, ident, entity_name = args - mapper = _class_mapper(class_, entity_name=entity_name) - return mapper.instance_key_from_primary_key(ident, - entity_name=entity_name) - else: - try: - instance = kwargs.pop("instance") - except KeyError: - row = kwargs.pop("row") - class_ = kwargs.pop("class") - entity_name = kwargs.pop("entity_name", None) - assert not kwargs, ("unknown keyword arguments: %s" - % (kwargs.keys(),)) - mapper = _class_mapper(class_, entity_name=entity_name) - return mapper.identity_key_from_row(row) else: - assert not kwargs, ("unknown keyword arguments: %s" - % (kwargs.keys(),)) - mapper = _object_mapper(instance) - return mapper.identity_key_from_instance(instance) + raise exceptions.ArgumentError("expected up to three " + "positional arguments, got %s" % len(args)) + if kwargs: + raise exceptions.ArgumentError("unknown keyword arguments: %s" + % ", ".join(kwargs.keys())) + mapper = _class_mapper(class_, entity_name=entity_name) + if "ident" in locals(): + return mapper.identity_key_from_primary_key(ident) + return mapper.identity_key_from_row(row) + instance = kwargs.pop("instance") + if kwargs: + raise exceptions.ArgumentError("unknown keyword arguments: %s" + % ", ".join(kwargs.keys())) + mapper = _object_mapper(instance) + return mapper.identity_key_from_instance(instance) def _save_impl(self, object, **kwargs): if hasattr(object, '_instance_key'): diff --git a/test/orm/session.py b/test/orm/session.py index 960f09309d..705e56e96c 100644 --- a/test/orm/session.py +++ b/test/orm/session.py @@ -176,6 +176,40 @@ class SessionTest(AssertMixin): assert s.query(Address).selectone().address_id == a.address_id assert s.query(User).selectfirst() is None + def _assert_key(self, got, expect): + assert got == expect, "expected %r got %r" % (expect, got) + + def test_identity_key_1(self): + mapper(User, users) + mapper(User, users, entity_name="en") + s = create_session() + key = s.identity_key(User, 1) + self._assert_key(key, (User, (1,), None)) + key = s.identity_key(User, 1, "en") + self._assert_key(key, (User, (1,), "en")) + key = s.identity_key(User, 1, entity_name="en") + self._assert_key(key, (User, (1,), "en")) + key = s.identity_key(User, ident=1, entity_name="en") + self._assert_key(key, (User, (1,), "en")) + + def test_identity_key_2(self): + mapper(User, users) + s = create_session() + u = User() + s.save(u) + s.flush() + key = s.identity_key(instance=u) + self._assert_key(key, (User, (u.user_id,), None)) + + def test_identity_key_3(self): + mapper(User, users) + mapper(User, users, entity_name="en") + s = create_session() + row = {users.c.user_id: 1, users.c.user_name: "Frank"} + key = s.identity_key(User, row=row) + self._assert_key(key, (User, (1,), None)) + key = s.identity_key(User, row=row, entity_name="en") + self._assert_key(key, (User, (1,), "en")) class OrphanDeletionTest(AssertMixin): -- 2.47.2