From: Mike Bayer Date: Thu, 2 Mar 2006 01:45:31 +0000 (+0000) Subject: added objectstore.refresh(), including supporting changes in mapper, attributes,... X-Git-Tag: rel_0_1_3~1 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=9985d21385ca8a478dcdd7d989982295775d0383;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git added objectstore.refresh(), including supporting changes in mapper, attributes, util --- diff --git a/lib/sqlalchemy/attributes.py b/lib/sqlalchemy/attributes.py index bea125371b..08178324fe 100644 --- a/lib/sqlalchemy/attributes.py +++ b/lib/sqlalchemy/attributes.py @@ -65,6 +65,8 @@ class PropHistory(object): self.extension = extension def gethistory(self, *args, **kwargs): return self + def clear(self): + del self.obj.__dict__[self.key] def history_contains(self, obj): return self.orig is obj or self.obj.__dict__[self.key] is obj def setattr_clean(self, value): @@ -314,7 +316,8 @@ class AttributeManager(object): pass def remove(self, obj): - """not sure what this is.""" + """called when an object is totally being removed from memory""" + # currently a no-op since the state of the object is attached to the object itself pass def create_history(self, obj, key, uselist, callable_=None, **kwargs): @@ -350,6 +353,8 @@ class AttributeManager(object): When the attribute is next accessed, a new container will be created via the class-level history container definition.""" try: + x = self.attribute_history(obj)[key] + x.clear() del self.attribute_history(obj)[key] except KeyError: pass diff --git a/lib/sqlalchemy/mapping/mapper.py b/lib/sqlalchemy/mapping/mapper.py index b4e3b84261..85fe8dc696 100644 --- a/lib/sqlalchemy/mapping/mapper.py +++ b/lib/sqlalchemy/mapping/mapper.py @@ -235,6 +235,7 @@ class Mapper(object): def instances(self, cursor, *mappers, **kwargs): limit = kwargs.get('limit', None) offset = kwargs.get('offset', None) + populate_existing = kwargs.get('populate_existing', False) result = util.HistoryArraySet() if len(mappers): @@ -247,7 +248,7 @@ class Mapper(object): row = cursor.fetchone() if row is None: break - self._instance(row, imap, result) + self._instance(row, imap, result, populate_existing=populate_existing) i = 0 for m in mappers: m._instance(row, imap, otherresults[i]) @@ -270,21 +271,25 @@ class Mapper(object): #print "key: " + repr(key) + " ident: " + repr(ident) return self._get(key, ident) - def _get(self, key, ident=None): - try: - return objectstore.get_session()._get(key) - except KeyError: - if ident is None: - ident = key[2] - i = 0 - params = {} - for primary_key in self.pks_by_table[self.table]: - params["pk_"+primary_key.key] = ident[i] - i += 1 + def _get(self, key, ident=None, reload=False): + if not reload: try: - return self.select(self._get_clause, params=params)[0] - except IndexError: - return None + return objectstore.get_session()._get(key) + except KeyError: + pass + + if ident is None: + ident = key[1] + i = 0 + params = {} + for primary_key in self.pks_by_table[self.table]: + params["pk_"+primary_key.key] = ident[i] + i += 1 + try: + statement = self._compile(self._get_clause) + return self._select_statement(statement, params=params, populate_existing=reload)[0] + except IndexError: + return None def identity_key(self, *primary_key): @@ -449,10 +454,7 @@ class Mapper(object): def select_whereclause(self, whereclause=None, params=None, **kwargs): statement = self._compile(whereclause, **kwargs) - if params is not None: - return self.select_statement(statement, **params) - else: - return self.select_statement(statement) + return self._select_statement(statement, params=params) def count(self, whereclause=None, params=None, **kwargs): s = self.table.count(whereclause) @@ -462,13 +464,18 @@ class Mapper(object): return s.scalar() def select_statement(self, statement, **params): - statement.use_labels = True - return self.instances(statement.execute(**params)) + return self._select_statement(statement, params=params) def select_text(self, text, **params): t = sql.text(text, engine=self.primarytable.engine) return self.instances(t.execute(**params)) + def _select_statement(self, statement, params=None, **kwargs): + statement.use_labels = True + if params is None: + params = {} + return self.instances(statement.execute(**params), **kwargs) + def _getpropbycolumn(self, column): try: prop = self.columntoproperty[column.original] @@ -722,11 +729,10 @@ class Mapper(object): isnew = False if populate_existing: - isnew = not imap.has_key(identitykey) - if isnew: + if not imap.has_key(identitykey): imap[identitykey] = instance for prop in self.props.values(): - prop.execute(instance, row, identitykey, imap, isnew) + prop.execute(instance, row, identitykey, imap, True) if self.extension.append_result(self, row, imap, result, instance, isnew, populate_existing=populate_existing): if result is not None: diff --git a/lib/sqlalchemy/mapping/objectstore.py b/lib/sqlalchemy/mapping/objectstore.py index df5f8d6b1d..d0e1573c68 100644 --- a/lib/sqlalchemy/mapping/objectstore.py +++ b/lib/sqlalchemy/mapping/objectstore.py @@ -144,6 +144,10 @@ class Session(object): if self.parent_uow is None: self.uow.commit() + def refresh(self, *obj): + for o in obj: + self.uow.refresh(o) + def register_clean(self, obj): self._bind_to(obj) self.uow.register_clean(obj) @@ -221,6 +225,11 @@ def clear(): current mapped object instances, as they are no longer in the Identity Map.""" get_session().clear() +def refresh(*obj): + """reloads the state of this object from the database, and cancels any in-memory + changes.""" + get_session().refresh(*obj) + def delete(*obj): """registers the given objects as to be deleted upon the next commit""" s = get_session().delete(*obj) @@ -308,6 +317,10 @@ class UnitOfWork(object): def _put(self, key, obj): self.identity_map[key] = obj + def refresh(self, obj): + self.rollback_object(obj) + object_mapper(obj)._get(obj._instance_key, reload=True) + def has_key(self, key): """returns True if the given key is present in this UnitOfWork's identity map.""" return self.identity_map.has_key(key) diff --git a/lib/sqlalchemy/util.py b/lib/sqlalchemy/util.py index a0e0e244c6..618900fcfa 100644 --- a/lib/sqlalchemy/util.py +++ b/lib/sqlalchemy/util.py @@ -327,6 +327,10 @@ class HistoryArraySet(UserList.UserList): self.records = {} for l in list: self.append_nohistory(l) + def clear(self): + """clears the list and removes all history.""" + self.data[:] = [] + self.records = {} def added_items(self): """returns a list of items that have been added since the last "committed" state.""" return [key for key in self.data if self.records[key] is True]