]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- "Passive defaults" and other "inline" defaults can now
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 28 Jan 2008 23:15:40 +0000 (23:15 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 28 Jan 2008 23:15:40 +0000 (23:15 +0000)
be loaded during a flush() call if needed; in particular,
this allows constructing relations() where a foreign key
column references a server-side-generated, non-primary-key
column. [ticket:954]

CHANGES
lib/sqlalchemy/orm/mapper.py
test/orm/expire.py
test/orm/unitofwork.py

diff --git a/CHANGES b/CHANGES
index 069192a445d247b9bb9627890546873a85deec29..2359c3508382e4cb286d0920ffedda58ac33f5cd 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -95,7 +95,13 @@ CHANGES
 
     - Fixed a rather expensive call in Query that was slowing
       down polymorphic queries.
-
+    
+    - "Passive defaults" and other "inline" defaults can now
+      be loaded during a flush() call if needed; in particular,
+      this allows constructing relations() where a foreign key
+      column references a server-side-generated, non-primary-key
+      column. [ticket:954]
+      
     - Every begin() must now be accompanied by a corresponding
       commit() or rollback(), including the implicit begin()
       in transactional sessions.
index 959cf274c4553fe45b5a3fa364507b121eefcbe9..1489ef65e16f0a33acfeb4044357c7f9a3171423 100644 (file)
@@ -1298,7 +1298,13 @@ class Mapper(object):
 
         # determine identity key
         if refresh_instance:
-            identitykey = refresh_instance.dict['_instance_key']
+            try:
+                identitykey = refresh_instance.dict['_instance_key']
+            except KeyError:
+                # super-rare condition; a refresh is being called 
+                # on a non-instance-key instance; this is meant to only 
+                # occur wihtin a flush()
+                identitykey = self._identity_key_from_state(refresh_instance)
         else:
             identitykey = self.identity_key_from_row(row)
 
@@ -1550,8 +1556,13 @@ def _load_scalar_attributes(instance, attribute_names):
             session = mapper.get_session()
         except exceptions.InvalidRequestError:
             raise exceptions.InvalidRequestError("Instance %s is not bound to a Session, and no contextual session is established; attribute refresh operation cannot proceed" % (instance.__class__))
-
-    if session.query(mapper)._get(instance._instance_key, refresh_instance=instance._state, only_load_props=attribute_names) is None:
+    
+    state = instance._state
+    if '_instance_key' in state.dict:
+        identity_key = state.dict['_instance_key']
+    else:
+        identity_key = mapper._identity_key_from_state(state)
+    if session.query(mapper)._get(identity_key, refresh_instance=state, only_load_props=attribute_names) is None:
         raise exceptions.InvalidRequestError("Could not refresh instance '%s'" % instance_str(instance))
 
 def _state_mapper(state, entity_name=None):
index 98963c0128dd99fc0dd0bd354f2d15aba33468fa..8a01017e41b876ea3c33b4e500be0bf08ced5862 100644 (file)
@@ -75,7 +75,23 @@ class ExpireTest(FixtureTest):
             u.name
         except exceptions.InvalidRequestError, e:
             assert str(e) == "Instance <class 'testlib.fixtures.User'> is not bound to a Session, and no contextual session is established; attribute refresh operation cannot proceed"
+    
+    def test_no_instance_key(self):
+        # this tests an artificial condition such that 
+        # an instance is pending, but has expired attributes.  this
+        # is actually part of a larger behavior when postfetch needs to 
+        # occur during a flush() on an instance that was just inserted
+        mapper(User, users)
+        sess = create_session()
+        u = sess.query(User).get(7)
 
+        sess.expire(u, attribute_names=['name'])
+        sess.expunge(u)
+        del u._instance_key
+        assert 'name' not in u.__dict__
+        sess.save(u)
+        assert u.name == 'jack'
+        
     def test_expire_preserves_changes(self):
         """test that the expire load operation doesn't revert post-expire changes"""
 
index bcd4d217598f4b1705901bcafb71cccb490a9d6b..92d526f789fe363bd75e583062592d5254a9d154 100644 (file)
@@ -681,7 +681,7 @@ class DefaultTest(ORMTest):
 
     def define_tables(self, metadata):
         db = testing.db
-        use_string_defaults = db.engine.__module__.endswith('postgres') or db.engine.__module__.endswith('oracle') or db.engine.__module__.endswith('sqlite')
+        use_string_defaults = testing.against('postgres', 'oracle', 'sqlite') 
         global hohoval, althohoval
 
         if use_string_defaults:
@@ -693,14 +693,24 @@ class DefaultTest(ORMTest):
             hohoval = 9
             althohoval = 15
 
-        global default_table
+        global default_table, secondary_table
         default_table = Table('default_test', metadata,
-        Column('id', Integer, Sequence("dt_seq", optional=True), primary_key=True),
-        Column('hoho', hohotype, PassiveDefault(str(hohoval))),
-        Column('counter', Integer, default=func.length("1234567")),
-        Column('foober', String(30), default="im foober", onupdate="im the update")
+            Column('id', Integer, Sequence("dt_seq", optional=True), primary_key=True),
+            Column('hoho', hohotype, PassiveDefault(str(hohoval))),
+            Column('counter', Integer, default=func.length("1234567")),
+            Column('foober', String(30), default="im foober", onupdate="im the update"),
         )
-
+        
+        secondary_table = Table('secondary_table', metadata, 
+            Column('id', Integer, primary_key=True),
+            Column('data', String(50))
+            )
+        
+        if testing.against('postgres', 'oracle'):
+            default_table.append_column(Column('secondary_id', Integer, Sequence('sec_id_seq'), unique=True))
+            secondary_table.append_column(Column('fk_val', Integer, ForeignKey('default_test.secondary_id')))
+        else:
+            secondary_table.append_column(Column('hoho', hohotype, ForeignKey('default_test.hoho')))
 
     def test_insert(self):
         class Hoho(object):pass
@@ -778,7 +788,38 @@ class DefaultTest(ORMTest):
         h1.counter = 19
         Session.commit()
         self.assertEquals(h1.foober, 'im the update')
+    
+    def test_used_in_relation(self):
+        """test that a server-side generated default can be used as the target of a foreign key"""
+        
+        class Hoho(fixtures.Base):
+            pass
+        class Secondary(fixtures.Base):
+            pass
+        mapper(Hoho, default_table, properties={
+            'secondaries':relation(Secondary)
+        }, save_on_init=False)
+        
+        mapper(Secondary, secondary_table, save_on_init=False)
+        h1 = Hoho()
+        s1 = Secondary(data='s1')
+        h1.secondaries.append(s1)
+        Session.save(h1)
+        Session.commit()
+        Session.clear()
+        
+        self.assertEquals(Session.query(Hoho).get(h1.id), Hoho(hoho=hohoval, secondaries=[Secondary(data='s1')]))
+        
+        h1 = Session.query(Hoho).get(h1.id)
+        h1.secondaries.append(Secondary(data='s2'))
+        Session.commit()
+        Session.clear()
 
+        self.assertEquals(Session.query(Hoho).get(h1.id), 
+            Hoho(hoho=hohoval, secondaries=[Secondary(data='s1'), Secondary(data='s2')])
+        )
+        
+            
 class OneToManyTest(ORMTest):
     metadata = tables.metadata