From: Mike Bayer Date: Sun, 28 Aug 2005 22:03:57 +0000 (+0000) Subject: (no commit message) X-Git-Tag: rel_0_1_0~824 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=ca3956f5ccf5c3d02ae64d57a6dfe24c26f5983c;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git --- diff --git a/lib/sqlalchemy/engine.py b/lib/sqlalchemy/engine.py index 9e662f455a..db1191a9bc 100644 --- a/lib/sqlalchemy/engine.py +++ b/lib/sqlalchemy/engine.py @@ -55,7 +55,7 @@ class SQLEngine(schema.SchemaEngine): self._pool = sqlalchemy.pool.manage(self.dbapi()).get_pool(*cargs, **cparams) self._echo = echo self.context = util.ThreadLocal() - + def schemagenerator(self, proxy, **params): raise NotImplementedError() @@ -64,7 +64,7 @@ class SQLEngine(schema.SchemaEngine): def reflecttable(self, table): raise NotImplementedError() - + def columnimpl(self, column): return sql.ColumnSelectable(column) @@ -72,22 +72,52 @@ class SQLEngine(schema.SchemaEngine): """returns a thread-local map of the generated primary keys corresponding to the most recent insert statement. keys are the names of columns.""" raise NotImplementedError() - + def connect_args(self): raise NotImplementedError() - + def dbapi(self): raise NotImplementedError() def compile(self, statement, bindparams): raise NotImplementedError() + def do_begin(self, connection): + """implementations might want to put logic here for turning autocommit on/off, etc.""" + pass + def do_rollback(self, connection): + """implementations might want to put logic here for turning autocommit on/off, etc.""" + connection.rollback() + def do_commit(self, connection): + """implementations might want to put logic here for turning autocommit on/off, etc.""" + connection.commit() + def proxy(self): return lambda s, p = None: self.execute(s, p) - + def connection(self): return self._pool.connect() + def multi_transaction(self, tables, func): + """provides a transaction boundary across tables which may be in multiple databases. + + clearly, this approach only goes so far, such as if database A commits, then database B commits + and fails, A is already committed. Any failure conditions have to be raised before anyone + commits for this to be useful.""" + engines = util.Set() + for table in tables: + engines.append(table.engine) + for engine in engines: + engine.begin() + try: + func() + except: + for engine in engines: + engine.rollback() + raise + for engine in engines: + engine.commit() + def transaction(self, func): self.begin() try: @@ -96,10 +126,12 @@ class SQLEngine(schema.SchemaEngine): self.rollback() raise self.commit() - + + def begin(self): if getattr(self.context, 'transaction', None) is None: conn = self.connection() + self.do_begin(conn) self.context.transaction = conn self.context.tcount = 1 else: @@ -107,7 +139,7 @@ class SQLEngine(schema.SchemaEngine): def rollback(self): if self.context.transaction is not None: - self.context.transaction.rollback() + self.do_rollback(self.context.transaction) self.context.transaction = None self.context.tcount = None @@ -116,7 +148,7 @@ class SQLEngine(schema.SchemaEngine): count = self.context.tcount - 1 self.context.tcount = count if count == 0: - self.context.transaction.commit() + self.do_commit(self.context.transaction) self.context.transaction = None self.context.tcount = None diff --git a/lib/sqlalchemy/mapper.py b/lib/sqlalchemy/mapper.py index 283bfb97bc..bef7474f24 100644 --- a/lib/sqlalchemy/mapper.py +++ b/lib/sqlalchemy/mapper.py @@ -63,10 +63,10 @@ def identitymap(): def globalidentity(): return _global_identitymap - + def eagerload(name): return EagerLazySwitcher(name, toeager = True) - + def lazyload(name): return EagerLazySwitcher(name, toeager = False) @@ -104,7 +104,7 @@ class Mapper(object): if isroot: self.init(self) - + def hash_key(self): return mapper_hash_key( self.class_, @@ -181,7 +181,7 @@ class Mapper(object): for option in options: option.process(mapper) return _mappers.setdefault(hashkey, mapper) - + def select(self, arg = None, **params): """selects instances of the object from the database. @@ -196,28 +196,37 @@ class Mapper(object): return self._select_statement(arg, **params) else: return self._select_whereclause(arg, **params) - + def save(self, obj, traverse = True, refetch = False): - """saves the object. based on the existence of its primary key, either inserts or updates. + """saves the object across all its primary tables. + based on the existence of the primary key for each table, either inserts or updates. primary key is determined by the underlying database engine's sequence methodology. - traverse indicates attached objects should be saved as well. + the traverse flag indicates attached objects should be saved as well. if smart attributes are being used for the object, the "dirty" flag, or the absense of the attribute, determines if the item is saved. if smart attributes are not being used, the item is saved unconditionally. """ - # TODO: support multi-table saves + if getattr(obj, 'dirty', True): - for table in self.tables: - for col in table.columns: - if getattr(obj, col.key, None) is None: - self.insert(obj, table) - break - else: - self.update(obj, table) + f = def(): + for table in self.tables: + for col in table.columns: + if getattr(obj, col.key, None) is None: + self.insert(obj, table) + break + else: + self.update(obj, table) + + for prop in self.props.values(): + prop.save(obj, traverse, refetch) + self.transaction(f) + else: + for prop in self.props.values(): + prop.save(obj, traverse, refetch) - for prop in self.props.values(): - prop.save(obj, traverse, refetch) + def transaction(self, f): + return self.table.engine.multi_transaction(self.tables, f) def remove(self, obj, traverse = True): """removes the object. traverse indicates attached objects should be removed as well.""" @@ -433,7 +442,7 @@ class PropertyLoader(MapperProperty): if self.primaryjoin is None: self.primaryjoin = match_primaries(parent.selectable, self.target) - def save(self, object, traverse, refetch): + def save(self, obj, traverse, refetch): # if a mapping table does not exist, save a row for all objects # in our list normally, setting their primary keys # else, determine the foreign key column in our table, set it to the parent @@ -441,15 +450,20 @@ class PropertyLoader(MapperProperty): # if a mapping table exists, determine the two foreign key columns # in the mapping table, set the two values, and insert that row, for # each row in the list - if self.secondary is None: - self.mapper.save(object) - else: - # TODO: crap, we dont have a simple version of what object props/cols match to which - pass + for child in getattr(obj, self.key): + setter = ForeignKeySetter(obj, child) + self.primaryjoin.accept_visitor(setter) + self.mapper.save(child) def delete(self): self.mapper.delete() +class ForeignKeySetter(ClauseVisitor): + def visit_binary(self, binary): + if binary.operator == '==': + if binary.left.table == self.primarytable and binary.right.table == self.secondarytable: + setattr(self.child, binary.left.colname, getattr(obj, binary.right.colname)) + class LazyLoader(PropertyLoader): def init(self, key, parent, root): diff --git a/test/mapper.py b/test/mapper.py index 7429070a36..d1f29e4389 100644 --- a/test/mapper.py +++ b/test/mapper.py @@ -250,18 +250,18 @@ class SaveTest(PersistTest): m = mapper(User, users, echo=True) m.save(u) m.save(u2) - + # assert the first one retreives the same from the identity map nu = m.get(u.user_id) self.assert_(u is nu) - + # clear out the identity map, so next get forces a SELECT m.identitymap.clear() # check it again, identity should be different but ids the same nu = m.get(u.user_id) self.assert_(u is not nu and u.user_id == nu.user_id and nu.user_name == 'savetester') - + # change first users name and save u.user_name = 'modifiedname' m.save(u) @@ -292,6 +292,6 @@ class SaveTest(PersistTest): print repr(usertable) addresstable = engine.ResultProxy(addresses.select().execute()).fetchall() print repr(addresstable) - + if __name__ == "__main__": - unittest.main() + unittest.main()