]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- Backport of "optimized get" fix from 0.7,
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 5 Dec 2010 03:13:10 +0000 (22:13 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 5 Dec 2010 03:13:10 +0000 (22:13 -0500)
improves the generation of joined-inheritance
"load expired row" behavior.  [ticket:1992]

CHANGES
lib/sqlalchemy/orm/mapper.py
test/orm/inheritance/test_basic.py
test/orm/test_expire.py

diff --git a/CHANGES b/CHANGES
index d1115e7a641352f6b06a62480434f7885e2216db..597dfb08fd78e0537fe89db0926ab99e8a3d5ae1 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -47,7 +47,11 @@ CHANGES
   - Query.get() will raise if the number of params
     in a composite key is too large, as well as too 
     small. [ticket:1977]
-    
+
+  - Backport of "optimized get" fix from 0.7,
+    improves the generation of joined-inheritance
+    "load expired row" behavior.  [ticket:1992]
+
 - sql
   - Fixed operator precedence rules for multiple
     chains of a single non-associative operator.
index c1045226cdca2ca86dacf58f5b2cfccc1241edf1..4ed0c62bc061df9b88d6828949c1296cca4a92da 100644 (file)
@@ -1322,10 +1322,10 @@ class Mapper(object):
         """
         props = self._props
         
-        tables = set(chain(*
-                        (sqlutil.find_tables(props[key].columns[0],
-                        check_columns=True) 
-                        for key in attribute_names)
+        tables = set(chain(
+                        *[sqlutil.find_tables(c, check_columns=True) 
+                        for key in attribute_names
+                        for c in props[key].columns]
                     ))
         
         if self.base_mapper.local_table in tables:
@@ -1344,7 +1344,7 @@ class Mapper(object):
                 leftval = self._get_committed_state_attr_by_column(
                                                     state, state.dict, 
                                                     leftcol, passive=True)
-                if leftval is attributes.PASSIVE_NO_RESULT:
+                if leftval is attributes.PASSIVE_NO_RESULT or leftval is None:
                     raise ColumnsNotAvailable()
                 binary.left = sql.bindparam(None, leftval,
                                             type_=binary.right.type)
@@ -1352,7 +1352,7 @@ class Mapper(object):
                 rightval = self._get_committed_state_attr_by_column(
                                                     state, state.dict, 
                                                     rightcol, passive=True)
-                if rightval is attributes.PASSIVE_NO_RESULT:
+                if rightval is attributes.PASSIVE_NO_RESULT or rightval is None:
                     raise ColumnsNotAvailable()
                 binary.right = sql.bindparam(None, rightval,
                                             type_=binary.right.type)
@@ -2442,6 +2442,7 @@ def _load_scalar_attributes(state, attribute_names):
     has_key = state.has_identity
     
     result = False
+
     if mapper.inherits and not mapper.concrete:
         statement = mapper._optimized_get_statement(state, attribute_names)
         if statement is not None:
index cc7dcba052f2fe627f16b85eaa713b3bb8563808..ced0104c552c114d571291f382fcfedf69d1845d 100644 (file)
@@ -3,7 +3,8 @@ from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message
 from sqlalchemy import *
 from sqlalchemy import exc as sa_exc, util
 from sqlalchemy.orm import *
-from sqlalchemy.orm import exc as orm_exc
+from sqlalchemy.orm import exc as orm_exc, attributes
+from sqlalchemy.test.assertsql import AllOf, CompiledSQL
 
 from sqlalchemy.test import testing, engines
 from sqlalchemy.util import function_named
@@ -1115,22 +1116,29 @@ class OptimizedLoadTest(_base.MappedTest):
     
     @classmethod
     def define_tables(cls, metadata):
-        global base, sub, with_comp
-        base = Table('base', metadata,
+        Table('base', metadata,
             Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('data', String(50)),
-            Column('type', String(50))
+            Column('type', String(50)),
+            Column('counter', Integer, server_default="1")
         )
-        sub = Table('sub', metadata, 
+        Table('sub', metadata, 
             Column('id', Integer, ForeignKey('base.id'), primary_key=True),
-            Column('sub', String(50))
+            Column('sub', String(50)),
+            Column('counter', Integer, server_default="1"),
+            Column('counter2', Integer, server_default="1")
+        )
+        Table('subsub', metadata,
+            Column('id', Integer, ForeignKey('sub.id'), primary_key=True),
+            Column('counter2', Integer, server_default="1")
         )
-        with_comp = Table('with_comp', metadata,
+        Table('with_comp', metadata,
             Column('id', Integer, ForeignKey('base.id'), primary_key=True),
             Column('a', String(10)),
             Column('b', String(10))
         )
     
+    @testing.resolve_artifact_names
     def test_optimized_passes(self):
         """"test that the 'optimized load' routine doesn't crash when 
         a column in the join condition is not available."""
@@ -1160,6 +1168,7 @@ class OptimizedLoadTest(_base.MappedTest):
         s1 = sess.query(Base).first()
         assert s1.sub == 's1sub'
 
+    @testing.resolve_artifact_names
     def test_column_expression(self):
         class Base(_base.BasicEntity):
             pass
@@ -1177,6 +1186,7 @@ class OptimizedLoadTest(_base.MappedTest):
         s1 = sess.query(Base).first()
         assert s1.concat == 's1sub|s1sub'
 
+    @testing.resolve_artifact_names
     def test_column_expression_joined(self):
         class Base(_base.ComparableEntity):
             pass
@@ -1206,6 +1216,7 @@ class OptimizedLoadTest(_base.MappedTest):
             ]
         )
 
+    @testing.resolve_artifact_names
     def test_composite_column_joined(self):
         class Base(_base.ComparableEntity):
             pass
@@ -1234,17 +1245,118 @@ class OptimizedLoadTest(_base.MappedTest):
         assert s2test.comp
         eq_(s1test.comp, Comp('ham', 'cheese'))
         eq_(s2test.comp, Comp('bacon', 'eggs'))
+    
+    @testing.resolve_artifact_names
+    def test_load_expired_on_pending(self):
+        class Base(_base.ComparableEntity):
+            pass
+        class Sub(Base):
+            pass
+        mapper(Base, base, polymorphic_on=base.c.type, polymorphic_identity='base')
+        mapper(Sub, sub, inherits=Base, polymorphic_identity='sub')
+        sess = Session()
+        s1 = Sub(data='s1')
+        sess.add(s1)
+        self.assert_sql_execution(
+                testing.db,
+                sess.flush,
+                CompiledSQL(
+                    "INSERT INTO base (data, type) VALUES (:data, :type)",
+                    [{'data':'s1','type':'sub'}]
+                ),
+                CompiledSQL(
+                    "SELECT base.counter AS base_counter, "
+                    "sub.counter AS sub_counter FROM base JOIN sub ON base.id = "
+                    "sub.id WHERE base.id = :param_1",
+                    lambda ctx:{'param_1':s1.id}
+                ),
+                CompiledSQL(
+                    "INSERT INTO sub (id, sub) VALUES (:id, :sub)",
+                    lambda ctx:{'id':s1.id, 'sub':None}
+                ),
+        )
+
+    @testing.resolve_artifact_names
+    def test_dont_generate_on_none(self):
+        class Base(_base.ComparableEntity):
+            pass
+        class Sub(Base):
+            pass
+        mapper(Base, base, polymorphic_on=base.c.type, 
+                            polymorphic_identity='base')
+        m = mapper(Sub, sub, inherits=Base, polymorphic_identity='sub')
+        
+        s1 = Sub()
+        assert m._optimized_get_statement(attributes.instance_state(s1), 
+                                ['counter2']) is None
+        
+        # loads s1.id as None
+        eq_(s1.id, None)
+        
+        # this now will come up with a value of None for id - should reject
+        assert m._optimized_get_statement(attributes.instance_state(s1), 
+                                ['counter2']) is None
+        
+        s1.id = 1
+        attributes.instance_state(s1).commit_all(s1.__dict__, None)
+        assert m._optimized_get_statement(attributes.instance_state(s1), 
+                                ['counter2']) is not None
+        
+    @testing.resolve_artifact_names
+    def test_load_expired_on_pending_twolevel(self):
+        class Base(_base.ComparableEntity):
+            pass
+        class Sub(Base):
+            pass
+        class SubSub(Sub):
+            pass
+            
+        mapper(Base, base, polymorphic_on=base.c.type, 
+                    polymorphic_identity='base')
+        mapper(Sub, sub, inherits=Base, polymorphic_identity='sub')
+        mapper(SubSub, subsub, inherits=Sub, polymorphic_identity='subsub')
+        sess = Session()
+        s1 = SubSub(data='s1', counter=1)
+        sess.add(s1)
+        self.assert_sql_execution(
+                testing.db,
+                sess.flush,
+                CompiledSQL(
+                    "INSERT INTO base (data, type, counter) VALUES "
+                    "(:data, :type, :counter)",
+                    [{'data':'s1','type':'subsub','counter':1}]
+                ),
+                CompiledSQL(
+                    "INSERT INTO sub (id, sub, counter) VALUES "
+                    "(:id, :sub, :counter)",
+                    lambda ctx:[{'counter': 1, 'sub': None, 'id': s1.id}]
+                ),
+                CompiledSQL(
+                    "SELECT sub.counter2 AS sub_counter2, "
+                    "subsub.counter2 AS subsub_counter2 FROM base "
+                    "JOIN sub ON base.id = sub.id JOIN "
+                    "subsub ON sub.id = subsub.id WHERE base.id = :param_1",
+                    lambda ctx:{u'param_1': s1.id}
+                ),
+                CompiledSQL(
+                    "INSERT INTO subsub (id) VALUES (:id)",
+                    lambda ctx:{'id':s1.id}
+                ),
+        )
+        
         
         
 class PKDiscriminatorTest(_base.MappedTest):
     @classmethod
     def define_tables(cls, metadata):
         parents = Table('parents', metadata,
-                           Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
+                           Column('id', Integer, primary_key=True, 
+                                    test_needs_autoincrement=True),
                            Column('name', String(60)))
                            
         children = Table('children', metadata,
-                        Column('id', Integer, ForeignKey('parents.id'), primary_key=True),
+                        Column('id', Integer, ForeignKey('parents.id'), 
+                                    primary_key=True),
                         Column('type', Integer,primary_key=True),
                         Column('name', String(60)))
 
index c8bdf1719e5f9d2b9b660cd1bb60228f3a391d69..f356658a1ad4198b942d531b53c2190185bb257f 100644 (file)
@@ -9,7 +9,7 @@ from sqlalchemy.test.schema import Table
 from sqlalchemy.test.schema import Column
 from sqlalchemy.orm import mapper, relationship, create_session, \
                         attributes, deferred, exc as orm_exc, defer, undefer,\
-                        strategies, state, lazyload, backref
+                        strategies, state, lazyload, backref, Session
 from test.orm import _base, _fixtures
 
 
@@ -240,7 +240,7 @@ class ExpireTest(_fixtures.FixtureTest):
         assert 'name' not in u.__dict__
         sess.add(u)
         assert_raises(sa_exc.InvalidRequestError, getattr, u, 'name')
-        
+
         
     @testing.resolve_artifact_names
     def test_expire_preserves_changes(self):
@@ -886,7 +886,8 @@ class PolymorphicExpireTest(_base.MappedTest):
            Column('person_id', Integer, primary_key=True,
                   test_needs_autoincrement=True),
            Column('name', String(50)),
-           Column('type', String(30)))
+           Column('type', String(30)),
+           )
 
         engineers = Table('engineers', metadata,
            Column('person_id', Integer, ForeignKey('people.person_id'),
@@ -913,11 +914,15 @@ class PolymorphicExpireTest(_base.MappedTest):
             {'person_id':2, 'status':'new engineer'},
             {'person_id':3, 'status':'old engineer'},
         )
-
+    
+    @classmethod
     @testing.resolve_artifact_names
-    def test_poly_deferred(self):
+    def setup_mappers(cls):
         mapper(Person, people, polymorphic_on=people.c.type, polymorphic_identity='person')
         mapper(Engineer, engineers, inherits=Person, polymorphic_identity='engineer')
+        
+    @testing.resolve_artifact_names
+    def test_poly_deferred(self):
 
         sess = create_session()
         [p1, e1, e2] = sess.query(Person).order_by(people.c.person_id).all()
@@ -953,6 +958,34 @@ class PolymorphicExpireTest(_base.MappedTest):
         self.assert_sql_count(testing.db, go, 2)
         eq_(Engineer.name.get_history(e1), (['new engineer name'],(), ['engineer1']))
 
+    @testing.resolve_artifact_names
+    def test_no_instance_key(self):
+        
+        sess = create_session()
+        e1 = sess.query(Engineer).get(2)
+
+        sess.expire(e1, attribute_names=['name'])
+        sess.expunge(e1)
+        attributes.instance_state(e1).key = None
+        assert 'name' not in e1.__dict__
+        sess.add(e1)
+        assert e1.name == 'engineer1'
+    
+    @testing.resolve_artifact_names
+    def test_no_instance_key(self):
+        # same as test_no_instance_key, but the PK columns
+        # are absent.  ensure an error is raised.
+        sess = create_session()
+        e1 = sess.query(Engineer).get(2)
+
+        sess.expire(e1, attribute_names=['name', 'person_id'])
+        sess.expunge(e1)
+        attributes.instance_state(e1).key = None
+        assert 'name' not in e1.__dict__
+        sess.add(e1)
+        assert_raises(sa_exc.InvalidRequestError, getattr, e1, 'name')
+
+
 class ExpiredPendingTest(_fixtures.FixtureTest):
     run_define_tables = 'once'
     run_setup_classes = 'once'