From a8c232258805ddcd5db464db06ed42a73a50c445 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Sat, 1 Mar 2008 22:30:02 +0000 Subject: [PATCH] - state.commit() and state.commit_all() now reconcile the current dict against expired_attributes and unset the expired flag for those attributes. This is partially so that attributes are not needlessly marked as expired after a two-phase inheritance load. - fixed bug which was introduced in 0.4.3, whereby loading an already-persistent instance mapped with joined table inheritance would trigger a useless "secondary" load from its joined table, when using the default "select" polymorphic_fetch. This was due to attributes being marked as expired during its first load and not getting unmarked from the previous "secondary" load. Attributes are now unexpired based on presence in __dict__ after any load or commit operation succeeds. --- CHANGES | 10 +++++++ lib/sqlalchemy/orm/attributes.py | 45 ++++++++++++++++++++++++------- lib/sqlalchemy/orm/mapper.py | 6 ++--- test/orm/expire.py | 2 ++ test/orm/inheritance/polymorph.py | 31 ++++++++++++--------- 5 files changed, 67 insertions(+), 27 deletions(-) diff --git a/CHANGES b/CHANGES index 3e5d36c20f..4faa8ee866 100644 --- a/CHANGES +++ b/CHANGES @@ -45,6 +45,16 @@ CHANGES - Fixed potential generative bug when the same Query was used to generate multiple Query objects using join(). + - fixed bug which was introduced in 0.4.3, whereby loading an + already-persistent instance mapped with joined table inheritance + would trigger a useless "secondary" load from its joined + table, when using the default "select" polymorphic_fetch. + This was due to attributes being marked as expired + during its first load and not getting unmarked from the + previous "secondary" load. Attributes are now unexpired + based on presence in __dict__ after any load or commit + operation succeeds. + - deprecated Query methods apply_sum(), apply_max(), apply_min(), apply_avg(). Better methodologies are coming.... diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py index 298a7f5119..5c5781d4e6 100644 --- a/lib/sqlalchemy/orm/attributes.py +++ b/lib/sqlalchemy/orm/attributes.py @@ -257,9 +257,6 @@ class AttributeImpl(object): """set an attribute value on the given instance and 'commit' it.""" state.commit_attr(self, value) - # remove per-instance callable, if any - state.callables.pop(self.key, None) - state.dict[self.key] = value return value class ScalarAttributeImpl(AttributeImpl): @@ -672,6 +669,9 @@ class ClassState(object): self.attrs = {} self.has_mutable_scalars = False +import sets +_empty_set = sets.ImmutableSet() + class InstanceState(object): """tracks state information at the instance level.""" @@ -687,6 +687,7 @@ class InstanceState(object): self.appenders = {} self.instance_dict = None self.runid = None + self.expired_attributes = _empty_set def __cleanup(self, ref): # tiptoe around Python GC unpredictableness @@ -751,7 +752,7 @@ class InstanceState(object): return None def __getstate__(self): - return {'committed_state':self.committed_state, 'pending':self.pending, 'parents':self.parents, 'modified':self.modified, 'instance':self.obj(), 'expired_attributes':getattr(self, 'expired_attributes', None), 'callables':self.callables} + return {'committed_state':self.committed_state, 'pending':self.pending, 'parents':self.parents, 'modified':self.modified, 'instance':self.obj(), 'expired_attributes':self.expired_attributes, 'callables':self.callables} def __setstate__(self, state): self.committed_state = state['committed_state'] @@ -764,8 +765,7 @@ class InstanceState(object): self.callables = state['callables'] self.runid = None self.appenders = {} - if state['expired_attributes'] is not None: - self.expire_attributes(state['expired_attributes']) + self.expired_attributes = state['expired_attributes'] def initialize(self, key): getattr(self.class_, key).impl.initialize(self) @@ -780,7 +780,6 @@ class InstanceState(object): serializable. """ instance = self.obj() - unmodified = self.unmodified self.class_._class_state.deferred_scalar_loader(instance, [ attr.impl.key for attr in _managed_attributes(self.class_) if @@ -804,8 +803,7 @@ class InstanceState(object): unmodified = property(unmodified) def expire_attributes(self, attribute_names): - if not hasattr(self, 'expired_attributes'): - self.expired_attributes = util.Set() + self.expired_attributes = util.Set(self.expired_attributes) if attribute_names is None: for attr in _managed_attributes(self.class_): @@ -829,18 +827,29 @@ class InstanceState(object): self.callables.pop(key, None) def commit_attr(self, attr, value): + """set the value of an attribute and mark it 'committed'.""" + if hasattr(attr, 'commit_to_state'): attr.commit_to_state(self, value) else: self.committed_state.pop(attr.key, None) + self.dict[attr.key] = value self.pending.pop(attr.key, None) self.appenders.pop(attr.key, None) + + # we have a value so we can also unexpire it + self.callables.pop(attr.key, None) + if attr.key in self.expired_attributes: + self.expired_attributes.remove(attr.key) def commit(self, keys): """commit all attributes named in the given list of key names. This is used by a partial-attribute load operation to mark committed those attributes which were refreshed from the database. + + Attributes marked as "expired" can potentially remain "expired" after this step + if a value was not populated in state.dict. """ if self.class_._class_state.has_mutable_scalars: @@ -857,12 +866,22 @@ class InstanceState(object): self.committed_state.pop(key, None) self.pending.pop(key, None) self.appenders.pop(key, None) - + + # unexpire attributes which have loaded + for key in self.expired_attributes.intersection(keys): + if key in self.dict: + self.expired_attributes.remove(key) + self.callables.pop(key, None) + + def commit_all(self): """commit all attributes unconditionally. This is used after a flush() or a regular instance load or refresh operation to mark committed all populated attributes. + + Attributes marked as "expired" can potentially remain "expired" after this step + if a value was not populated in state.dict. """ self.committed_state = {} @@ -870,6 +889,12 @@ class InstanceState(object): self.pending = {} self.appenders = {} + # unexpire attributes which have loaded + for key in list(self.expired_attributes): + if key in self.dict: + self.expired_attributes.remove(key) + self.callables.pop(key, None) + if self.class_._class_state.has_mutable_scalars: for attr in _managed_attributes(self.class_): if hasattr(attr.impl, 'commit_to_state') and attr.impl.key in self.dict: diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 297d222466..f89830c02c 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -1373,10 +1373,9 @@ class Mapper(object): self.populate_instance(context, instance, row, only_load_props=only_load_props, instancekey=identitykey, isnew=isnew) else: - attrs = getattr(state, 'expired_attributes', None) # populate attributes on non-loading instances which have been expired # TODO: also support deferred attributes here [ticket:870] - if attrs: + if state.expired_attributes: if state in context.partials: isnew = False attrs = context.partials[state] @@ -1483,7 +1482,7 @@ class Mapper(object): self.__log_debug("Post query loading instance " + instance_str(instance)) identitykey = self.identity_key_from_instance(instance) - + only_load_props = flags.get('only_load_props', None) params = {} @@ -1563,7 +1562,6 @@ object_session = None def _load_scalar_attributes(instance, attribute_names): mapper = object_mapper(instance) - global object_session if not object_session: from sqlalchemy.orm.session import object_session diff --git a/test/orm/expire.py b/test/orm/expire.py index 545f01234d..dca56dfb8f 100644 --- a/test/orm/expire.py +++ b/test/orm/expire.py @@ -39,10 +39,12 @@ class ExpireTest(FixtureTest): sess.expire(u) # object isnt refreshed yet, using dict to bypass trigger assert u.__dict__.get('name') != 'jack' + assert 'name' in u._state.expired_attributes sess.query(User).all() # test that it refreshed assert u.__dict__['name'] == 'jack' + assert 'name' not in u._state.expired_attributes def go(): assert u.name == 'jack' diff --git a/test/orm/inheritance/polymorph.py b/test/orm/inheritance/polymorph.py index faee633601..4b468e227c 100644 --- a/test/orm/inheritance/polymorph.py +++ b/test/orm/inheritance/polymorph.py @@ -268,7 +268,7 @@ def generate_round_trip_test(include_base=False, lazy_relation=True, redefine_co c.employees.append(Engineer(status='CGG', engineer_name='engineer2', primary_language='python', **{person_attribute_name:'wally'})) c.employees.append(Manager(status='ABA', manager_name='manager2', **{person_attribute_name:'jsmith'})) session.save(c) - print session.new + session.flush() session.clear() @@ -284,7 +284,6 @@ def generate_round_trip_test(include_base=False, lazy_relation=True, redefine_co def go(): c = session.query(Company).get(id) for e in c.employees: - print e, e._instance_key, e.company assert e._instance_key[0] == Person if include_base: assert sets.Set([(e.get_name(), getattr(e, 'status', None)) for e in c.employees]) == sets.Set([('pointy haired boss', 'AAB'), ('dilbert', 'BBA'), ('joesmith', None), ('wally', 'CGG'), ('jsmith', 'ABA')]) @@ -307,25 +306,31 @@ def generate_round_trip_test(include_base=False, lazy_relation=True, redefine_co # test selecting from the query, using the base mapped table (people) as the selection criterion. # in the case of the polymorphic Person query, the "people" selectable should be adapted to be "person_join" dilbert = session.query(Person).filter(getattr(Person, person_attribute_name)=='dilbert').first() - dilbert2 = session.query(Engineer).filter(getattr(Person, person_attribute_name)=='dilbert').first() - assert dilbert is dilbert2 + assert dilbert is session.query(Engineer).filter(getattr(Person, person_attribute_name)=='dilbert').first() # test selecting from the query, joining against an alias of the base "people" table. test that # the "palias" alias does *not* get sucked up into the "person_join" conversion. palias = people.alias("palias") - session.query(Person).filter((palias.c.name=='dilbert') & (palias.c.person_id==Person.person_id)).first() - dilbert2 = session.query(Engineer).filter((palias.c.name=='dilbert') & (palias.c.person_id==Person.person_id)).first() - assert dilbert is dilbert2 - - session.query(Person).filter((Engineer.engineer_name=="engineer1") & (Engineer.person_id==people.c.person_id)).first() - - dilbert2 = session.query(Engineer).filter(Engineer.engineer_name=="engineer1")[0] - assert dilbert is dilbert2 - + assert dilbert is session.query(Person).filter((palias.c.name=='dilbert') & (palias.c.person_id==Person.person_id)).first() + assert dilbert is session.query(Engineer).filter((palias.c.name=='dilbert') & (palias.c.person_id==Person.person_id)).first() + assert dilbert is session.query(Person).filter((Engineer.engineer_name=="engineer1") & (engineers.c.person_id==people.c.person_id)).first() + assert dilbert is session.query(Engineer).filter(Engineer.engineer_name=="engineer1")[0] + dilbert.engineer_name = 'hes dibert!' session.flush() session.clear() + + if polymorphic_fetch == 'select': + def go(): + session.query(Person).filter(getattr(Person, person_attribute_name)=='dilbert').first() + self.assert_sql_count(testing.db, go, 2) + session.clear() + dilbert = session.query(Person).filter(getattr(Person, person_attribute_name)=='dilbert').first() + def go(): + # assert that only primary table is queried for already-present-in-session + d = session.query(Person).filter(getattr(Person, person_attribute_name)=='dilbert').first() + self.assert_sql_count(testing.db, go, 1) # save/load some managers/bosses b = Boss(status='BBB', manager_name='boss', golf_swing='fore', **{person_attribute_name:'daboss'}) -- 2.47.3