]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
dan's latest patch for session.identity_key()
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 22 Mar 2007 20:54:52 +0000 (20:54 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 22 Mar 2007 20:54:52 +0000 (20:54 +0000)
lib/sqlalchemy/orm/session.py
test/orm/session.py

index 72a1c4c85e9315de51017e41590343c5d4a027a4..51d850d3c170d7aab9bea65a286b57ff68171b83 100644 (file)
@@ -452,54 +452,57 @@ class Session(object):
 
             identity_key(class\_, ident, entity_name=None)
                 class\_
-                    mapped class
-                
+                    mapped class (must be a positional argument)
+
                 ident
                     primary key, if the key is composite this is a tuple
                 
                 entity_name
-                    optional entity name. May be given as a
-                    positional arg or as a keyword arg.
+                    optional entity name
 
             identity_key(instance=instance)
                 instance
                     object instance (must be given as a keyword arg)
 
-            identity_key(row=row, class=class\_, entity_name=None)
+            identity_key(class\_, row=row, entity_name=None)
+                class\_
+                    mapped class (must be a positional argument)
+                
                 row
                     result proxy row (must be given as a keyword arg)
-            
+
+                entity_name
+                    optional entity name (must be given as a keyword arg)
         """
         if args:
-            kw = {}
-            if len(args) == 2:
+            if len(args) == 1:
+                class_ = args[0]
+                try:
+                    row = kwargs.pop("row")
+                except KeyError:
+                    ident = kwargs.pop("ident")
+                entity_name = kwargs.pop("entity_name", None)
+            elif len(args) == 2:
                 class_, ident = args
                 entity_name = kwargs.pop("entity_name", None)
-                assert not kwargs, ("unknown keyword arguments: %s"
-                    % (kwargs.keys(),))
-            else:
-                assert len(args) == 3, ("two or three positional args are "
-                    "accepted, got %s" % len(args))
+            elif len(args) == 3:
                 class_, ident, entity_name = args
-            mapper = _class_mapper(class_, entity_name=entity_name)
-            return mapper.instance_key_from_primary_key(ident,
-                entity_name=entity_name)
-        else:
-            try:
-                instance = kwargs.pop("instance")
-            except KeyError:
-                row = kwargs.pop("row")
-                class_ = kwargs.pop("class")
-                entity_name = kwargs.pop("entity_name", None)
-                assert not kwargs, ("unknown keyword arguments: %s"
-                    % (kwargs.keys(),))
-                mapper = _class_mapper(class_, entity_name=entity_name)
-                return mapper.identity_key_from_row(row)
             else:
-                assert not kwargs, ("unknown keyword arguments: %s"
-                    % (kwargs.keys(),))
-                mapper = _object_mapper(instance)
-                return mapper.identity_key_from_instance(instance)
+                raise exceptions.ArgumentError("expected up to three "
+                    "positional arguments, got %s" % len(args))
+            if kwargs:
+                raise exceptions.ArgumentError("unknown keyword arguments: %s"
+                    % ", ".join(kwargs.keys()))
+            mapper = _class_mapper(class_, entity_name=entity_name)
+            if "ident" in locals():
+                return mapper.identity_key_from_primary_key(ident)
+            return mapper.identity_key_from_row(row)
+        instance = kwargs.pop("instance")
+        if kwargs:
+            raise exceptions.ArgumentError("unknown keyword arguments: %s"
+                % ", ".join(kwargs.keys()))
+        mapper = _object_mapper(instance)
+        return mapper.identity_key_from_instance(instance)
 
     def _save_impl(self, object, **kwargs):
         if hasattr(object, '_instance_key'):
index 960f09309dfc3a5706fe99ac8a0c63208a6c5b3a..705e56e96c627966b76c241665cab0e4fcc02d82 100644 (file)
@@ -176,6 +176,40 @@ class SessionTest(AssertMixin):
         assert s.query(Address).selectone().address_id == a.address_id
         assert s.query(User).selectfirst() is None
 
+    def _assert_key(self, got, expect):
+        assert got == expect, "expected %r got %r" % (expect, got)
+
+    def test_identity_key_1(self):
+        mapper(User, users)
+        mapper(User, users, entity_name="en")
+        s = create_session()
+        key = s.identity_key(User, 1)
+        self._assert_key(key, (User, (1,), None))
+        key = s.identity_key(User, 1, "en")
+        self._assert_key(key, (User, (1,), "en"))
+        key = s.identity_key(User, 1, entity_name="en")
+        self._assert_key(key, (User, (1,), "en"))
+        key = s.identity_key(User, ident=1, entity_name="en")
+        self._assert_key(key, (User, (1,), "en"))
+
+    def test_identity_key_2(self):
+        mapper(User, users)
+        s = create_session()
+        u = User()
+        s.save(u)
+        s.flush()
+        key = s.identity_key(instance=u)
+        self._assert_key(key, (User, (u.user_id,), None))
+
+    def test_identity_key_3(self):
+        mapper(User, users)
+        mapper(User, users, entity_name="en")
+        s = create_session()
+        row = {users.c.user_id: 1, users.c.user_name: "Frank"}
+        key = s.identity_key(User, row=row)
+        self._assert_key(key, (User, (1,), None))
+        key = s.identity_key(User, row=row, entity_name="en")
+        self._assert_key(key, (User, (1,), "en"))
         
 class OrphanDeletionTest(AssertMixin):