]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- dynamic relations, when referenced, create a strong
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 16 Jan 2008 22:06:15 +0000 (22:06 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 16 Jan 2008 22:06:15 +0000 (22:06 +0000)
reference to the parent object so that the query
still has a parent to call against even if the
parent is only created (and otherwise dereferenced)
within the scope of a single expression [ticket:938]

CHANGES
lib/sqlalchemy/orm/dynamic.py
test/orm/dynamic.py

diff --git a/CHANGES b/CHANGES
index b555603b7b0357e9825a66cc40168e36155f59cb..af2101e5da2ee1157175015b91b25a1521b3fbc8 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -13,6 +13,12 @@ CHANGES
     - proper error message is raised when trying to access
       expired instance attributes with no session present
 
+    - dynamic relations, when referenced, create a strong
+      reference to the parent object so that the query
+      still has a parent to call against even if the 
+      parent is only created (and otherwise dereferenced)
+      within the scope of a single expression [ticket:938]
+      
     - added a mapper() flag "eager_defaults"; when set to
       True, defaults that are generated during an INSERT or
       UPDATE operation are post-fetched immediately, instead
index fe781ab05e4bbb7d7afa44f8ccc81c8e2b26cf63..a5e1a84678de61af13df08eb98b91c9143eb88ae 100644 (file)
@@ -80,15 +80,14 @@ class DynamicAttributeImpl(attributes.AttributeImpl):
 class AppenderQuery(Query):
     def __init__(self, attr, state):
         super(AppenderQuery, self).__init__(attr.target_mapper, None)
-        self.state = state
+        self.instance = state.obj()
         self.attr = attr
     
     def __session(self):
-        instance = self.state.obj()
-        sess = object_session(instance)
-        if sess is not None and self.autoflush and sess.autoflush and instance in sess:
+        sess = object_session(self.instance)
+        if sess is not None and self.autoflush and sess.autoflush and self.instance in sess:
             sess.flush()
-        if not has_identity(instance):
+        if not has_identity(self.instance):
             return None
         else:
             return sess
@@ -100,21 +99,21 @@ class AppenderQuery(Query):
     def __iter__(self):
         sess = self.__session()
         if sess is None:
-            return iter(self.attr._get_collection(self.state, passive=True).added_items)
+            return iter(self.attr._get_collection(self.instance._state, passive=True).added_items)
         else:
             return iter(self._clone(sess))
 
     def __getitem__(self, index):
         sess = self.__session()
         if sess is None:
-            return self.attr._get_collection(self.state, passive=True).added_items.__getitem__(index)
+            return self.attr._get_collection(self.instance._state, passive=True).added_items.__getitem__(index)
         else:
             return self._clone(sess).__getitem__(index)
     
     def count(self):
         sess = self.__session()
         if sess is None:
-            return len(self.attr._get_collection(self.state, passive=True).added_items)
+            return len(self.attr._get_collection(self.instance._state, passive=True).added_items)
         else:
             return self._clone(sess).count()
     
@@ -122,7 +121,7 @@ class AppenderQuery(Query):
         # note we're returning an entirely new Query class instance here
         # without any assignment capabilities;
         # the class of this query is determined by the session.
-        instance = self.state.obj()
+        instance = self.instance
         if sess is None:
             sess = object_session(instance)
             if sess is None:
@@ -134,19 +133,19 @@ class AppenderQuery(Query):
         return sess.query(self.attr.target_mapper).with_parent(instance, self.attr.key)
 
     def assign(self, collection):
-        instance = self.state.obj()
+        instance = self.instance
         if has_identity(instance):
             oldlist = list(self)
         else:
             oldlist = []
-        self.attr._get_collection(self.state, passive=True).replace(oldlist, collection)
+        self.attr._get_collection(self.instance._state, passive=True).replace(oldlist, collection)
         return oldlist
         
     def append(self, item):
-        self.attr.append(self.state, item, None)
+        self.attr.append(self.instance._state, item, None)
 
     def remove(self, item):
-        self.attr.remove(self.state, item, None)
+        self.attr.remove(self.instance._state, item, None)
 
             
 class CollectionHistory(object): 
index f188a478ff3e86b90545c4a597993fc64994324a..199eb474ffa789621aee557032e01502b5fb2dee 100644 (file)
@@ -213,5 +213,57 @@ for autoflush in (False, True):
     for saveuser in (False, True):
         create_backref_test(autoflush, saveuser)
 
+class DontDereferenceTest(ORMTest):
+    def define_tables(self, metadata):
+        global users_table, addresses_table
+        
+        users_table = Table('users', metadata,
+                           Column('id', Integer, primary_key=True),
+                           Column('name', String(40)),
+                           Column('fullname', String(100)),
+                           Column('password', String(15)))
+
+        addresses_table = Table('addresses', metadata,
+                                Column('id', Integer, primary_key=True),
+                                Column('email_address', String(100), nullable=False),
+                                Column('user_id', Integer, ForeignKey('users.id')))
+    def test_no_deref(self):
+        mapper(User, users_table, properties={
+            'addresses': relation(Address, backref='user', lazy='dynamic')
+            })
+
+        mapper(Address, addresses_table)
+
+        session = create_session()
+        user = User()
+        user.name = 'joe'
+        user.fullname = 'Joe User'
+        user.password = 'Joe\'s secret'
+        address = Address()
+        address.email_address = 'joe@joesdomain.example'
+        address.user = user
+        session.save(user)
+        session.flush()
+        session.clear()
+        
+        def query1():
+            session = create_session(metadata.bind)
+            user = session.query(User).first()
+            return user.addresses.all()
+
+        def query2():
+            session = create_session(metadata.bind)
+            return session.query(User).first().addresses.all()
+
+        def query3():
+            session = create_session(metadata.bind)
+            user = session.query(User).first()
+            return session.query(User).first().addresses.all()
+
+        self.assertEquals(query1(), [Address(email_address='joe@joesdomain.example')]  )
+        self.assertEquals(query2(), [Address(email_address='joe@joesdomain.example')]  )
+        self.assertEquals(query3(), [Address(email_address='joe@joesdomain.example')]  )
+        
+        
 if __name__ == '__main__':
     testenv.main()