From: Mike Bayer Date: Sun, 5 Dec 2010 03:13:10 +0000 (-0500) Subject: - Backport of "optimized get" fix from 0.7, X-Git-Tag: rel_0_6_6~31^2~2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=c8f8124cc9f42d9b3422cc4a0cfe8c6807237c1a;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - Backport of "optimized get" fix from 0.7, improves the generation of joined-inheritance "load expired row" behavior. [ticket:1992] --- diff --git a/CHANGES b/CHANGES index d1115e7a64..597dfb08fd 100644 --- a/CHANGES +++ b/CHANGES @@ -47,7 +47,11 @@ CHANGES - Query.get() will raise if the number of params in a composite key is too large, as well as too small. [ticket:1977] - + + - Backport of "optimized get" fix from 0.7, + improves the generation of joined-inheritance + "load expired row" behavior. [ticket:1992] + - sql - Fixed operator precedence rules for multiple chains of a single non-associative operator. diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index c1045226cd..4ed0c62bc0 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -1322,10 +1322,10 @@ class Mapper(object): """ props = self._props - tables = set(chain(* - (sqlutil.find_tables(props[key].columns[0], - check_columns=True) - for key in attribute_names) + tables = set(chain( + *[sqlutil.find_tables(c, check_columns=True) + for key in attribute_names + for c in props[key].columns] )) if self.base_mapper.local_table in tables: @@ -1344,7 +1344,7 @@ class Mapper(object): leftval = self._get_committed_state_attr_by_column( state, state.dict, leftcol, passive=True) - if leftval is attributes.PASSIVE_NO_RESULT: + if leftval is attributes.PASSIVE_NO_RESULT or leftval is None: raise ColumnsNotAvailable() binary.left = sql.bindparam(None, leftval, type_=binary.right.type) @@ -1352,7 +1352,7 @@ class Mapper(object): rightval = self._get_committed_state_attr_by_column( state, state.dict, rightcol, passive=True) - if rightval is attributes.PASSIVE_NO_RESULT: + if rightval is attributes.PASSIVE_NO_RESULT or rightval is None: raise ColumnsNotAvailable() binary.right = sql.bindparam(None, rightval, type_=binary.right.type) @@ -2442,6 +2442,7 @@ def _load_scalar_attributes(state, attribute_names): has_key = state.has_identity result = False + if mapper.inherits and not mapper.concrete: statement = mapper._optimized_get_statement(state, attribute_names) if statement is not None: diff --git a/test/orm/inheritance/test_basic.py b/test/orm/inheritance/test_basic.py index cc7dcba052..ced0104c55 100644 --- a/test/orm/inheritance/test_basic.py +++ b/test/orm/inheritance/test_basic.py @@ -3,7 +3,8 @@ from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message from sqlalchemy import * from sqlalchemy import exc as sa_exc, util from sqlalchemy.orm import * -from sqlalchemy.orm import exc as orm_exc +from sqlalchemy.orm import exc as orm_exc, attributes +from sqlalchemy.test.assertsql import AllOf, CompiledSQL from sqlalchemy.test import testing, engines from sqlalchemy.util import function_named @@ -1115,22 +1116,29 @@ class OptimizedLoadTest(_base.MappedTest): @classmethod def define_tables(cls, metadata): - global base, sub, with_comp - base = Table('base', metadata, + Table('base', metadata, Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('data', String(50)), - Column('type', String(50)) + Column('type', String(50)), + Column('counter', Integer, server_default="1") ) - sub = Table('sub', metadata, + Table('sub', metadata, Column('id', Integer, ForeignKey('base.id'), primary_key=True), - Column('sub', String(50)) + Column('sub', String(50)), + Column('counter', Integer, server_default="1"), + Column('counter2', Integer, server_default="1") + ) + Table('subsub', metadata, + Column('id', Integer, ForeignKey('sub.id'), primary_key=True), + Column('counter2', Integer, server_default="1") ) - with_comp = Table('with_comp', metadata, + Table('with_comp', metadata, Column('id', Integer, ForeignKey('base.id'), primary_key=True), Column('a', String(10)), Column('b', String(10)) ) + @testing.resolve_artifact_names def test_optimized_passes(self): """"test that the 'optimized load' routine doesn't crash when a column in the join condition is not available.""" @@ -1160,6 +1168,7 @@ class OptimizedLoadTest(_base.MappedTest): s1 = sess.query(Base).first() assert s1.sub == 's1sub' + @testing.resolve_artifact_names def test_column_expression(self): class Base(_base.BasicEntity): pass @@ -1177,6 +1186,7 @@ class OptimizedLoadTest(_base.MappedTest): s1 = sess.query(Base).first() assert s1.concat == 's1sub|s1sub' + @testing.resolve_artifact_names def test_column_expression_joined(self): class Base(_base.ComparableEntity): pass @@ -1206,6 +1216,7 @@ class OptimizedLoadTest(_base.MappedTest): ] ) + @testing.resolve_artifact_names def test_composite_column_joined(self): class Base(_base.ComparableEntity): pass @@ -1234,17 +1245,118 @@ class OptimizedLoadTest(_base.MappedTest): assert s2test.comp eq_(s1test.comp, Comp('ham', 'cheese')) eq_(s2test.comp, Comp('bacon', 'eggs')) + + @testing.resolve_artifact_names + def test_load_expired_on_pending(self): + class Base(_base.ComparableEntity): + pass + class Sub(Base): + pass + mapper(Base, base, polymorphic_on=base.c.type, polymorphic_identity='base') + mapper(Sub, sub, inherits=Base, polymorphic_identity='sub') + sess = Session() + s1 = Sub(data='s1') + sess.add(s1) + self.assert_sql_execution( + testing.db, + sess.flush, + CompiledSQL( + "INSERT INTO base (data, type) VALUES (:data, :type)", + [{'data':'s1','type':'sub'}] + ), + CompiledSQL( + "SELECT base.counter AS base_counter, " + "sub.counter AS sub_counter FROM base JOIN sub ON base.id = " + "sub.id WHERE base.id = :param_1", + lambda ctx:{'param_1':s1.id} + ), + CompiledSQL( + "INSERT INTO sub (id, sub) VALUES (:id, :sub)", + lambda ctx:{'id':s1.id, 'sub':None} + ), + ) + + @testing.resolve_artifact_names + def test_dont_generate_on_none(self): + class Base(_base.ComparableEntity): + pass + class Sub(Base): + pass + mapper(Base, base, polymorphic_on=base.c.type, + polymorphic_identity='base') + m = mapper(Sub, sub, inherits=Base, polymorphic_identity='sub') + + s1 = Sub() + assert m._optimized_get_statement(attributes.instance_state(s1), + ['counter2']) is None + + # loads s1.id as None + eq_(s1.id, None) + + # this now will come up with a value of None for id - should reject + assert m._optimized_get_statement(attributes.instance_state(s1), + ['counter2']) is None + + s1.id = 1 + attributes.instance_state(s1).commit_all(s1.__dict__, None) + assert m._optimized_get_statement(attributes.instance_state(s1), + ['counter2']) is not None + + @testing.resolve_artifact_names + def test_load_expired_on_pending_twolevel(self): + class Base(_base.ComparableEntity): + pass + class Sub(Base): + pass + class SubSub(Sub): + pass + + mapper(Base, base, polymorphic_on=base.c.type, + polymorphic_identity='base') + mapper(Sub, sub, inherits=Base, polymorphic_identity='sub') + mapper(SubSub, subsub, inherits=Sub, polymorphic_identity='subsub') + sess = Session() + s1 = SubSub(data='s1', counter=1) + sess.add(s1) + self.assert_sql_execution( + testing.db, + sess.flush, + CompiledSQL( + "INSERT INTO base (data, type, counter) VALUES " + "(:data, :type, :counter)", + [{'data':'s1','type':'subsub','counter':1}] + ), + CompiledSQL( + "INSERT INTO sub (id, sub, counter) VALUES " + "(:id, :sub, :counter)", + lambda ctx:[{'counter': 1, 'sub': None, 'id': s1.id}] + ), + CompiledSQL( + "SELECT sub.counter2 AS sub_counter2, " + "subsub.counter2 AS subsub_counter2 FROM base " + "JOIN sub ON base.id = sub.id JOIN " + "subsub ON sub.id = subsub.id WHERE base.id = :param_1", + lambda ctx:{u'param_1': s1.id} + ), + CompiledSQL( + "INSERT INTO subsub (id) VALUES (:id)", + lambda ctx:{'id':s1.id} + ), + ) + class PKDiscriminatorTest(_base.MappedTest): @classmethod def define_tables(cls, metadata): parents = Table('parents', metadata, - Column('id', Integer, primary_key=True, test_needs_autoincrement=True), + Column('id', Integer, primary_key=True, + test_needs_autoincrement=True), Column('name', String(60))) children = Table('children', metadata, - Column('id', Integer, ForeignKey('parents.id'), primary_key=True), + Column('id', Integer, ForeignKey('parents.id'), + primary_key=True), Column('type', Integer,primary_key=True), Column('name', String(60))) diff --git a/test/orm/test_expire.py b/test/orm/test_expire.py index c8bdf1719e..f356658a1a 100644 --- a/test/orm/test_expire.py +++ b/test/orm/test_expire.py @@ -9,7 +9,7 @@ from sqlalchemy.test.schema import Table from sqlalchemy.test.schema import Column from sqlalchemy.orm import mapper, relationship, create_session, \ attributes, deferred, exc as orm_exc, defer, undefer,\ - strategies, state, lazyload, backref + strategies, state, lazyload, backref, Session from test.orm import _base, _fixtures @@ -240,7 +240,7 @@ class ExpireTest(_fixtures.FixtureTest): assert 'name' not in u.__dict__ sess.add(u) assert_raises(sa_exc.InvalidRequestError, getattr, u, 'name') - + @testing.resolve_artifact_names def test_expire_preserves_changes(self): @@ -886,7 +886,8 @@ class PolymorphicExpireTest(_base.MappedTest): Column('person_id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', String(50)), - Column('type', String(30))) + Column('type', String(30)), + ) engineers = Table('engineers', metadata, Column('person_id', Integer, ForeignKey('people.person_id'), @@ -913,11 +914,15 @@ class PolymorphicExpireTest(_base.MappedTest): {'person_id':2, 'status':'new engineer'}, {'person_id':3, 'status':'old engineer'}, ) - + + @classmethod @testing.resolve_artifact_names - def test_poly_deferred(self): + def setup_mappers(cls): mapper(Person, people, polymorphic_on=people.c.type, polymorphic_identity='person') mapper(Engineer, engineers, inherits=Person, polymorphic_identity='engineer') + + @testing.resolve_artifact_names + def test_poly_deferred(self): sess = create_session() [p1, e1, e2] = sess.query(Person).order_by(people.c.person_id).all() @@ -953,6 +958,34 @@ class PolymorphicExpireTest(_base.MappedTest): self.assert_sql_count(testing.db, go, 2) eq_(Engineer.name.get_history(e1), (['new engineer name'],(), ['engineer1'])) + @testing.resolve_artifact_names + def test_no_instance_key(self): + + sess = create_session() + e1 = sess.query(Engineer).get(2) + + sess.expire(e1, attribute_names=['name']) + sess.expunge(e1) + attributes.instance_state(e1).key = None + assert 'name' not in e1.__dict__ + sess.add(e1) + assert e1.name == 'engineer1' + + @testing.resolve_artifact_names + def test_no_instance_key(self): + # same as test_no_instance_key, but the PK columns + # are absent. ensure an error is raised. + sess = create_session() + e1 = sess.query(Engineer).get(2) + + sess.expire(e1, attribute_names=['name', 'person_id']) + sess.expunge(e1) + attributes.instance_state(e1).key = None + assert 'name' not in e1.__dict__ + sess.add(e1) + assert_raises(sa_exc.InvalidRequestError, getattr, e1, 'name') + + class ExpiredPendingTest(_fixtures.FixtureTest): run_define_tables = 'once' run_setup_classes = 'once'