From: Mike Bayer Date: Mon, 11 Feb 2008 19:22:34 +0000 (+0000) Subject: - added expire_all() method to Session. Calls expire() X-Git-Tag: rel_0_4_3~15 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=6f9aa3a9003d4d63348bc56f612690a153da640c;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - added expire_all() method to Session. Calls expire() for all persistent instances. This is handy in conjunction with ..... - instances which have been partially or fully expired will have their expired attributes populated during a regular Query operation which affects those objects, preventing a needless second SQL statement for each instance. --- diff --git a/CHANGES b/CHANGES index e94826405a..e28563201e 100644 --- a/CHANGES +++ b/CHANGES @@ -105,7 +105,16 @@ CHANGES - The proper error message is raised when trying to access expired instance attributes with no session present - + + - added expire_all() method to Session. Calls expire() + for all persistent instances. This is handy in conjunction + with ..... + + - instances which have been partially or fully expired + will have their expired attributes populated during a regular + Query operation which affects those objects, preventing + a needless second SQL statement for each instance. + - Dynamic relations, when referenced, create a strong reference to the parent object so that the query still has a parent to call against even if the parent is only created diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py index e08a1a0c2c..5ae79e4323 100644 --- a/lib/sqlalchemy/orm/attributes.py +++ b/lib/sqlalchemy/orm/attributes.py @@ -775,7 +775,14 @@ class InstanceState(object): serializable. """ instance = self.obj() - self.class_._class_state.deferred_scalar_loader(instance, [k for k in self.expired_attributes if k in self.unmodified]) + + unmodified = self.unmodified + self.class_._class_state.deferred_scalar_loader(instance, [ + attr.impl.key for attr in _managed_attributes(self.class_) if + attr.impl.accepts_scalar_loader and + attr.impl.key in self.expired_attributes and + attr.impl.key in unmodified + ]) for k in self.expired_attributes: self.callables.pop(k, None) self.expired_attributes.clear() @@ -798,20 +805,18 @@ class InstanceState(object): if attribute_names is None: for attr in _managed_attributes(self.class_): self.dict.pop(attr.impl.key, None) - + self.expired_attributes.add(attr.impl.key) if attr.impl.accepts_scalar_loader: self.callables[attr.impl.key] = self - self.expired_attributes.add(attr.impl.key) self.committed_state = {} else: for key in attribute_names: self.dict.pop(key, None) self.committed_state.pop(key, None) - + self.expired_attributes.add(key) if getattr(self.class_, key).impl.accepts_scalar_loader: self.callables[key] = self - self.expired_attributes.add(key) def reset(self, key): """remove the given attribute and any callables associated with it.""" diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index d78973e942..85aec2f447 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -1371,13 +1371,22 @@ class Mapper(object): if 'populate_instance' not in extension.methods or extension.populate_instance(self, context, row, instance, only_load_props=only_load_props, instancekey=identitykey, isnew=isnew) is EXT_CONTINUE: self.populate_instance(context, instance, row, only_load_props=only_load_props, instancekey=identitykey, isnew=isnew) + + else: + attrs = getattr(state, 'expired_attributes', None) + # populate attributes on non-loading instances which have been expired + # TODO: also support deferred attributes here [ticket:870] + if attrs: + if state in context.partials: + isnew = False + attrs = context.partials[state] + else: + isnew = True + attrs = state.expired_attributes.intersection(state.unmodified) + context.partials[state] = attrs #<-- allow query.instances to commit the subset of attrs -# NOTYET: populate attributes on non-loading instances which have been expired, deferred, etc. -# elif getattr(state, 'expired_attributes', None): # TODO: base off total set of unloaded attributes, not just exp -# attrs = state.expired_attributes.intersection(state.unmodified) -# if 'populate_instance' not in extension.methods or extension.populate_instance(self, context, row, instance, only_load_props=attrs, instancekey=identitykey, isnew=isnew) is EXT_CONTINUE: -# self.populate_instance(context, instance, row, only_load_props=attrs, instancekey=identitykey, isnew=isnew) -# context.partials.add((state, attrs)) <-- allow query.instances to commit the subset of attrs + if 'populate_instance' not in extension.methods or extension.populate_instance(self, context, row, instance, only_load_props=attrs, instancekey=identitykey, isnew=isnew) is EXT_CONTINUE: + self.populate_instance(context, instance, row, only_load_props=attrs, instancekey=identitykey, isnew=isnew) if result is not None and ('append_result' not in extension.methods or extension.append_result(self, context, row, instance, result, instancekey=identitykey, isnew=isnew) is EXT_CONTINUE): result.append(instance) @@ -1448,10 +1457,10 @@ class Mapper(object): if self.non_primary: selectcontext.attributes[('populating_mapper', instance._state)] = self - def _post_instance(self, selectcontext, state): + def _post_instance(self, selectcontext, state, **kwargs): post_processors = selectcontext.attributes[('post_processors', self, None)] for p in post_processors: - p(state.obj()) + p(state.obj(), **kwargs) def _get_poly_select_loader(self, selectcontext, row): """set up attribute loaders for 'select' and 'deferred' polymorphic loading. @@ -1475,11 +1484,13 @@ class Mapper(object): identitykey = self.identity_key_from_instance(instance) + only_load_props = flags.get('only_load_props', None) + params = {} for c, bind in param_names: params[bind] = self._get_attr_by_column(instance, c) row = selectcontext.session.connection(self).execute(statement, params).fetchone() - self.populate_instance(selectcontext, instance, row, isnew=False, instancekey=identitykey, ispostselect=True) + self.populate_instance(selectcontext, instance, row, isnew=False, instancekey=identitykey, ispostselect=True, only_load_props=only_load_props) return post_execute elif hosted_mapper.polymorphic_fetch == 'deferred': from sqlalchemy.orm.strategies import DeferredColumnLoader @@ -1494,6 +1505,12 @@ class Mapper(object): props = [prop for prop in [self._get_col_to_prop(col) for col in statement.inner_columns] if prop.key not in instance.__dict__] keys = [p.key for p in props] + + only_load_props = flags.get('only_load_props', None) + if only_load_props: + keys = util.Set(keys).difference(only_load_props) + props = [p for p in props if p.key in only_load_props] + for prop in props: strategy = prop._get_strategy(DeferredColumnLoader) instance._state.set_callable(prop.key, strategy.setup_loader(instance, props=keys, create_statement=create_statement)) diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 35f632f1e7..2bb87ea715 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -903,6 +903,7 @@ class Query(object): while True: context.progress = util.Set() + context.partials = {} if self._yield_per: fetch = cursor.fetchmany(self._yield_per) @@ -927,7 +928,11 @@ class Query(object): for ii in context.progress: context.attributes.get(('populating_mapper', ii), _state_mapper(ii))._post_instance(context, ii) ii.commit_all() - + + for ii, attrs in context.partials.items(): + context.attributes.get(('populating_mapper', ii), _state_mapper(ii))._post_instance(context, ii, only_load_props=attrs) + ii.commit(attrs) + for row in rows: yield row diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index c75b786644..8f85a496c4 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -820,7 +820,14 @@ class Session(object): if self.query(_object_mapper(instance))._get(instance._instance_key, refresh_instance=instance._state, only_load_props=attribute_names) is None: raise exceptions.InvalidRequestError("Could not refresh instance '%s'" % mapperutil.instance_str(instance)) - + + def expire_all(self): + """Expires all persistent instances within this Session. + + """ + for state in self.identity_map.all_states(): + _expire_state(state, None) + def expire(self, instance, attribute_names=None): """Expire the attributes on the given instance. @@ -829,13 +836,6 @@ class Session(object): to the database which will refresh all attributes with their current value. - Lazy-loaded relational attributes will remain lazily loaded, so that - triggering one will incur the instance-wide refresh operation, followed - immediately by the lazy load of that attribute. - - Eagerly-loaded relational attributes will eagerly load within the - single refresh operation. - The ``attribute_names`` argument is an iterable collection of attribute names indicating a subset of attributes to be expired. diff --git a/test/orm/expire.py b/test/orm/expire.py index 3394c751bd..545f01234d 100644 --- a/test/orm/expire.py +++ b/test/orm/expire.py @@ -6,6 +6,7 @@ from sqlalchemy import exceptions from sqlalchemy.orm import * from testlib import * from testlib.fixtures import * +import gc class ExpireTest(FixtureTest): keep_mappers = False @@ -39,16 +40,13 @@ class ExpireTest(FixtureTest): # object isnt refreshed yet, using dict to bypass trigger assert u.__dict__.get('name') != 'jack' - if False: - # NOTYET: need to implement unconditional population - # of expired attriutes in mapper._instances() - sess.query(User).all() - # test that it refreshed - assert u.__dict__['name'] == 'jack' + sess.query(User).all() + # test that it refreshed + assert u.__dict__['name'] == 'jack' - def go(): - assert u.name == 'jack' - self.assert_sql_count(testing.db, go, 0) + def go(): + assert u.name == 'jack' + self.assert_sql_count(testing.db, go, 0) def test_expire_doesntload_on_set(self): mapper(User, users) @@ -122,16 +120,21 @@ class ExpireTest(FixtureTest): assert o.isopen == 1 assert o.description == 'some new description' - if False: - # NOTYET: need to implement unconditional population - # of expired attriutes in mapper._instances() - sess.expire(o, ['isopen', 'description']) - sess.query(Order).all() - del o.isopen - def go(): - assert o.isopen is None - self.assert_sql_count(testing.db, go, 0) + sess.expire(o, ['isopen', 'description']) + sess.query(Order).all() + del o.isopen + def go(): + assert o.isopen is None + self.assert_sql_count(testing.db, go, 0) + o.isopen=14 + sess.expire(o) + o.description = 'another new description' + sess.query(Order).all() + assert o.isopen == 1 + assert o.description == 'another new description' + + def test_expire_committed(self): """test that the committed state of the attribute receives the most recent DB data""" mapper(Order, orders) @@ -200,7 +203,7 @@ class ExpireTest(FixtureTest): assert u.addresses[0].email_address == 'jack@bean.com' assert u.name == 'jack' # two loads, since relation() + scalar are - # separate right now + # separate right now on per-attribute load self.assert_sql_count(testing.db, go, 2) assert 'name' in u.__dict__ assert 'addresses' in u.__dict__ @@ -209,6 +212,50 @@ class ExpireTest(FixtureTest): assert 'name' not in u.__dict__ assert 'addresses' not in u.__dict__ + def go(): + sess.query(User).filter_by(id=7).one() + assert u.addresses[0].email_address == 'jack@bean.com' + assert u.name == 'jack' + # one load, since relation() + scalar are + # together when eager load used with Query + self.assert_sql_count(testing.db, go, 1) + + def test_relation_changes_preserved(self): + mapper(User, users, properties={ + 'addresses':relation(Address, backref='user', lazy=False), + }) + mapper(Address, addresses) + sess = create_session() + u = sess.query(User).get(8) + sess.expire(u, ['name', 'addresses']) + u.addresses + assert 'name' not in u.__dict__ + del u.addresses[1] + u.name + assert 'name' in u.__dict__ + assert len(u.addresses) == 2 + + def test_eagerload_props_dontload(self): + # relations currently have to load separately from scalar instances. the use case is: + # expire "addresses". then access it. lazy load fires off to load "addresses", but needs + # foreign key or primary key attributes in order to lazy load; hits those attributes, + # such as below it hits "u.id". "u.id" triggers full unexpire operation, eagerloads + # addresses since lazy=False. this is all wihtin lazy load which fires unconditionally; + # so an unnecessary eagerload (or lazyload) was issued. would prefer not to complicate + # lazyloading to "figure out" that the operation should be aborted right now. + + mapper(User, users, properties={ + 'addresses':relation(Address, backref='user', lazy=False), + }) + mapper(Address, addresses) + sess = create_session() + u = sess.query(User).get(8) + sess.expire(u) + u.id + assert 'addresses' not in u.__dict__ + u.addresses + assert 'addresses' in u.__dict__ + def test_expire_synonym(self): mapper(User, users, properties={ 'uname':synonym('name') @@ -361,6 +408,25 @@ class ExpireTest(FixtureTest): # doing it that way right now #self.assert_sql_count(testing.db, go, 0) + def test_relations_load_on_query(self): + mapper(User, users, properties={ + 'addresses':relation(Address, backref='user'), + }) + mapper(Address, addresses) + + sess = create_session() + u = sess.query(User).get(8) + assert 'name' in u.__dict__ + u.addresses + assert 'addresses' in u.__dict__ + + sess.expire(u, ['name', 'addresses']) + assert 'name' not in u.__dict__ + assert 'addresses' not in u.__dict__ + sess.query(User).options(eagerload('addresses')).filter_by(id=8).all() + assert 'name' in u.__dict__ + assert 'addresses' in u.__dict__ + def test_partial_expire_deferred(self): mapper(Order, orders, properties={ 'description':deferred(orders.c.description) @@ -426,8 +492,149 @@ class ExpireTest(FixtureTest): assert o.description == 'order 3' assert o.isopen == 1 self.assert_sql_count(testing.db, go, 1) + + def test_eagerload_query_refreshes(self): + mapper(User, users, properties={ + 'addresses':relation(Address, backref='user', lazy=False), + }) + mapper(Address, addresses) + + sess = create_session() + u = sess.query(User).get(8) + assert len(u.addresses) == 3 + sess.expire(u) + assert 'addresses' not in u.__dict__ + print "-------------------------------------------" + sess.query(User).filter_by(id=8).all() + assert 'addresses' in u.__dict__ + assert len(u.addresses) == 3 + + def test_expire_all(self): + mapper(User, users, properties={ + 'addresses':relation(Address, backref='user', lazy=False), + }) + mapper(Address, addresses) + + sess = create_session() + userlist = sess.query(User).all() + assert fixtures.user_address_result == userlist + assert len(list(sess)) == 9 + sess.expire_all() + gc.collect() + assert len(list(sess)) == 4 # since addresses were gc'ed + + userlist = sess.query(User).all() + u = userlist[1] + assert fixtures.user_address_result == userlist + assert len(list(sess)) == 9 + +class PolymorphicExpireTest(ORMTest): + keep_data = True + + def define_tables(self, metadata): + global people, engineers, Person, Engineer + + people = Table('people', metadata, + Column('person_id', Integer, Sequence('person_id_seq', optional=True), primary_key=True), + Column('name', String(50)), + Column('type', String(30))) + + engineers = Table('engineers', metadata, + Column('person_id', Integer, ForeignKey('people.person_id'), primary_key=True), + Column('status', String(30)), + ) + + class Person(Base): + pass + class Engineer(Person): + pass + + def insert_data(self): + people.insert().execute( + {'person_id':1, 'name':'person1', 'type':'person'}, + {'person_id':2, 'name':'engineer1', 'type':'engineer'}, + {'person_id':3, 'name':'engineer2', 'type':'engineer'}, + ) + engineers.insert().execute( + {'person_id':2, 'status':'new engineer'}, + {'person_id':3, 'status':'old engineer'}, + ) + + def test_poly_select(self): + mapper(Person, people, polymorphic_on=people.c.type, polymorphic_identity='person') + mapper(Engineer, engineers, inherits=Person, polymorphic_identity='engineer') + + sess = create_session() + [p1, e1, e2] = sess.query(Person).order_by(people.c.person_id).all() + + sess.expire(p1) + sess.expire(e1, ['status']) + sess.expire(e2) + + for p in [p1, e2]: + assert 'name' not in p.__dict__ + + assert 'name' in e1.__dict__ + assert 'status' not in e2.__dict__ + assert 'status' not in e1.__dict__ + + e1.name = 'new engineer name' + + def go(): + sess.query(Person).all() + self.assert_sql_count(testing.db, go, 3) + + for p in [p1, e1, e2]: + assert 'name' in p.__dict__ + + assert 'status' in e2.__dict__ + assert 'status' in e1.__dict__ + def go(): + assert e1.name == 'new engineer name' + assert e2.name == 'engineer2' + assert e1.status == 'new engineer' + self.assert_sql_count(testing.db, go, 0) + self.assertEquals(Engineer.name.get_history(e1), (['new engineer name'], [], ['engineer1'])) + + def test_poly_deferred(self): + mapper(Person, people, polymorphic_on=people.c.type, polymorphic_identity='person', polymorphic_fetch='deferred') + mapper(Engineer, engineers, inherits=Person, polymorphic_identity='engineer') + + sess = create_session() + [p1, e1, e2] = sess.query(Person).order_by(people.c.person_id).all() + + sess.expire(p1) + sess.expire(e1, ['status']) + sess.expire(e2) + + for p in [p1, e2]: + assert 'name' not in p.__dict__ + + assert 'name' in e1.__dict__ + assert 'status' not in e2.__dict__ + assert 'status' not in e1.__dict__ + + e1.name = 'new engineer name' + + def go(): + sess.query(Person).all() + self.assert_sql_count(testing.db, go, 1) + + for p in [p1, e1, e2]: + assert 'name' in p.__dict__ + assert 'status' not in e2.__dict__ + assert 'status' not in e1.__dict__ + def go(): + assert e1.name == 'new engineer name' + assert e2.name == 'engineer2' + assert e1.status == 'new engineer' + assert e2.status == 'old engineer' + self.assert_sql_count(testing.db, go, 2) + self.assertEquals(Engineer.name.get_history(e1), (['new engineer name'], [], ['engineer1'])) + + class RefreshTest(FixtureTest): keep_mappers = False refresh_data = True