From: Mike Bayer Date: Sun, 28 Aug 2005 01:27:19 +0000 (+0000) Subject: dev X-Git-Tag: rel_0_1_0~828 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=f91cfdb8ac8794ea8e9f348b21ebbbf553d8bd4b;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git dev --- diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py index 42014ca21d..e0dcc58bd7 100644 --- a/lib/sqlalchemy/ansisql.py +++ b/lib/sqlalchemy/ansisql.py @@ -58,6 +58,7 @@ class ANSICompiler(sql.Compiled): self.froms = {} self.wheres = {} self.strings = {} + self.isinsert = False def get_from_text(self, obj): return self.froms[obj] @@ -200,6 +201,7 @@ class ANSICompiler(sql.Compiled): " ON " + self.get_str(join.onclause)) def visit_insert(self, insert_stmt): + self.isinsert = True colparams = insert_stmt.get_colparams(self.bindparams) for c in colparams: b = c[1] diff --git a/lib/sqlalchemy/databases/postgres.py b/lib/sqlalchemy/databases/postgres.py index fa5124ed11..f6d56d207e 100644 --- a/lib/sqlalchemy/databases/postgres.py +++ b/lib/sqlalchemy/databases/postgres.py @@ -52,15 +52,22 @@ class PGSQLEngine(ansisql.ANSISQLEngine): statement.accept_visitor(compiler) return compiler + def last_inserted_ids(self): + return self.context.last_inserted_ids + def pre_exec(self, connection, cursor, statement, parameters, echo = None, compiled = None, **kwargs): if compiled is None: return if getattr(compiled, "isinsert", False): + last_inserted_ids = [] for primary_key in compiled.statement.table.primary_keys: # pseudocode if echo is True or self._echo: self.log(primary_key.sequence.text) res = cursor.execute(primary_key.sequence.text) - parameters[primary_key.key] = res.fetchrow()[0] + newid = res.fetchrow()[0] + parameters[primary_key.key] = newid + last_inserted_ids.append(newid) + self.context.last_inserted_ids = last_inserted_ids def dbapi(self): return None @@ -73,10 +80,8 @@ class PGSQLEngine(ansisql.ANSISQLEngine): raise NotImplementedError() class PGCompiler(ansisql.ANSICompiler): - def visit_insert(self, insert): - self.isinsert = True - super(self).visit_insert(insert) - + pass + class PGColumnImpl(sql.ColumnSelectable): def get_specification(self): coltype = self.column.type diff --git a/lib/sqlalchemy/databases/sqlite.py b/lib/sqlalchemy/databases/sqlite.py index 315374a6d4..475f687378 100644 --- a/lib/sqlalchemy/databases/sqlite.py +++ b/lib/sqlalchemy/databases/sqlite.py @@ -49,8 +49,13 @@ class SQLiteSQLEngine(ansisql.ANSISQLEngine): self.opts = opts or {} ansisql.ANSISQLEngine.__init__(self, **params) + def post_exec(self, connection, cursor, statement, parameters, echo = None, compiled = None, **kwargs): + if compiled is None: return + if getattr(compiled, "isinsert", False): + self.context.last_inserted_ids = [cursor.lastrowid] + def last_inserted_ids(self): - pass + return self.context.last_inserted_ids def connect_args(self): return ([self.filename], self.opts) @@ -81,5 +86,7 @@ class SQLiteColumnImpl(sql.ColumnSelectable): else: key = coltype.__class__ - return self.name + " " + colspecs[key] % {'precision': getattr(coltype, 'precision', None), 'length' : getattr(coltype, 'length', None)} - + colspec = self.name + " " + colspecs[key] % {'precision': getattr(coltype, 'precision', None), 'length' : getattr(coltype, 'length', None)} + if self.column.primary_key: + colspec += " PRIMARY KEY" + return colspec diff --git a/lib/sqlalchemy/engine.py b/lib/sqlalchemy/engine.py index 22cc034348..e6183c328f 100644 --- a/lib/sqlalchemy/engine.py +++ b/lib/sqlalchemy/engine.py @@ -150,8 +150,9 @@ class SQLEngine(schema.SchemaEngine): class ResultProxy: - def __init__(self, cursor): + def __init__(self, cursor, echo = False): self.cursor = cursor + self.echo = echo metadata = cursor.description self.props = {} i = 0 @@ -164,7 +165,7 @@ class ResultProxy: def fetchone(self): row = self.cursor.fetchone() if row is not None: - #print repr(row) + if self.echo: print repr(row) return RowProxy(self, row) else: return None diff --git a/lib/sqlalchemy/mapper.py b/lib/sqlalchemy/mapper.py index 4d506f8e23..d8b6caf465 100644 --- a/lib/sqlalchemy/mapper.py +++ b/lib/sqlalchemy/mapper.py @@ -52,7 +52,7 @@ def relation_mapper(class_, selectable, secondary = None, primaryjoin = None, se _mappers = {} def mapper(*args, **params): hashkey = mapper_hash_key(*args, **params) - print "HASHKEY: " + hashkey + #print "HASHKEY: " + hashkey try: return _mappers[hashkey] except KeyError: @@ -121,7 +121,7 @@ class Mapper(object): def instances(self, cursor): result = [] - cursor = engine.ResultProxy(cursor) + cursor = engine.ResultProxy(cursor, echo = self.echo) localmap = {} while True: @@ -131,9 +131,30 @@ class Mapper(object): self._instance(row, localmap, result) return result - def get(self, id): - """returns an instance of the object based on the given ID.""" - pass + def get(self, *ident): + """returns an instance of the object based on the given identifier, or None + if not found. The *ident argument is a + list of primary keys in the order of the table def's primary keys.""" + key = self.identitymap.get_id_key(ident, self.class_, self.table, self.selectable) + try: + return self.identitymap[key] + except KeyError: + clause = sql.and_() + i = 0 + for primary_key in self.selectable.primary_keys: + # appending to the and_'s clause list directly to skip + # typechecks etc. + clause.clauses.append(primary_key == ident[i]) + i += 2 + try: + return self.select(clause)[0] + except IndexError: + return None + + def put(self, instance): + key = self.identitymap.get_instance_key(instance, self.class_, self.table, self.selectable) + self.identitymap[key] = instance + return key def compile(self, whereclause = None, **options): """works like select, except returns the SQL statement object without @@ -188,29 +209,45 @@ class Mapper(object): """removes the object. traverse indicates attached objects should be removed as well.""" pass - def insert(self, object): + def insert(self, obj): """inserts the object into its table, regardless of primary key being set. this is a lower-level operation than save.""" params = {} for col in self.table.columns: - params[col.key] = getattr(object, col.key) + params[col.key] = getattr(obj, col.key, None) ins = self.table.insert() ins.execute(**params) + + # TODO: unset dirty flag + + # populate new primary keys primary_keys = self.table.engine.last_inserted_ids() - # TODO: put the primary keys into the object props + index = 0 + for pk in self.table.primary_keys: + newid = primary_keys[index] + index += 1 + # TODO: do this via the ColumnProperty objects + setattr(obj, pk.key, newid) - def update(self, object): + self.put(obj) + + def update(self, obj): """inserts the object into its table, regardless of primary key being set. this is a lower-level operation than save.""" params = {} for col in self.table.columns: - params[col.key] = getattr(object, col.key) + params[col.key] = getattr(obj, col.key) upd = self.table.update() upd.execute(**params) - - def delete(self, object): + # TODO: unset dirty flag + + def delete(self, obj): """deletes the object's row from its table unconditionally. this is a lower-level operation than remove.""" + # delete dependencies ? + # delete row + # remove primary keys + # unset dirty flag pass class TableFinder(sql.ClauseVisitor): @@ -234,7 +271,7 @@ class Mapper(object): def _select_whereclause(self, whereclause = None, **params): statement = self._compile(whereclause) return self._select_statement(statement, **params) - + def _select_statement(self, statement, **params): statement.use_labels = True statement.echo = self.echo @@ -243,12 +280,13 @@ class Mapper(object): def _identity_key(self, row): return self.identitymap.get_key(row, self.class_, self.table, self.selectable) + def _instance(self, row, localmap, result): """pulls an object instance from the given row and appends it to the given result list. if the instance already exists in the given identity map, its not added. in either case, executes all the property loaders on the instance to also process extra information in the row.""" - + # create the instance if its not in the identity map, # else retrieve it identitykey = self._identity_key(row) @@ -323,7 +361,7 @@ class ColumnProperty(MapperProperty): def hash_key(self): return "ColumnProperty(%s)" % hash_key(self.column) - + def init(self, key, parent, root): self.key = key if root.use_smart_properties: @@ -363,8 +401,6 @@ def mapper_hash_key(class_, selectable, table = None, properties = None, identit ) ) - - class PropertyLoader(MapperProperty): def __init__(self, mapper, secondary, primaryjoin, secondaryjoin): self.mapper = mapper @@ -373,10 +409,10 @@ class PropertyLoader(MapperProperty): self.primaryjoin = primaryjoin self.secondaryjoin = secondaryjoin self._hash_key = "%s(%s, %s, %s, %s)" % (self.__class__.__name__, hash_key(mapper), hash_key(secondary), hash_key(primaryjoin), hash_key(secondaryjoin)) - + def hash_key(self): return self._hash_key - + def init(self, key, parent, root): self.key = key self.mapper.init(root) @@ -397,12 +433,15 @@ class PropertyLoader(MapperProperty): # if a mapping table exists, determine the two foreign key columns # in the mapping table, set the two values, and insert that row, for # each row in the list - pass + if self.secondary is None: + self.mapper.save(object) + else: + # TODO: crap, we dont have a simple version of what object props/cols match to which + pass def delete(self): self.mapper.delete() - class LazyLoader(PropertyLoader): def init(self, key, parent, root): @@ -424,14 +463,22 @@ class LazyLoader(PropertyLoader): def execute(self, instance, row, identitykey, localmap, isduplicate): if not isduplicate: - def load(): - m = {} - for key, value in self.binds.iteritems(): - m[key] = row[key] - return self.mapper.select(self.lazywhere, **m) + setattr(instance, self.key, LazyLoadInstance(self, row)) + +class LazyLoadInstance(object): + """attached to a specific object instance to load related rows. this is implemetned + as a callable object, rather than a closure, to allow serialization of the target object""" + def __init__(self, lazyloader, row): + self.params = {} + for key, value in lazyloader.binds.iteritems(): + self.params[key] = row[key] + # TODO: dont attach to the mapper, its huge. + # figure out some way to shrink this. + self.mapper = lazyloader.mapper + + def __call__(self): + return self.mapper.select(self.lazywhere, **self.params) - setattr(instance, self.key, load) - class EagerLoader(PropertyLoader): def init(self, key, parent, root): PropertyLoader.init(self, key, parent, root) @@ -440,8 +487,7 @@ class EagerLoader(PropertyLoader): if self.secondaryjoin is not None: [self.to_alias.append(f) for f in self.secondaryjoin._get_from_objects()] del self.to_alias[parent.selectable] - - + def setup(self, key, primarytable, statement, **options): """add a left outer join to the statement thats being constructed""" @@ -565,10 +611,14 @@ def match_primaries(primary, secondary): return sql.and_([pk == secondary.c[pk.name] for pk in primary.primary_keys]) class IdentityMap(dict): + def get_id_key(self, ident, class_, table, selectable): + return (class_, table, tuple(ident)) + def get_instance_key(self, object, class_, table, selectable): + return (class_, table, tuple([getattr(object, column.key, None) for column in selectable.primary_keys])) def get_key(self, row, class_, table, selectable): return (class_, table, tuple([row[column.label] for column in selectable.primary_keys])) def hash_key(self): return "IdentityMap(%s)" % id(self) - + _global_identitymap = IdentityMap() diff --git a/test/mapper.py b/test/mapper.py index ee5a2d4591..1e2eabd926 100644 --- a/test/mapper.py +++ b/test/mapper.py @@ -7,6 +7,7 @@ class User(object): def __repr__(self): return ( """ +objid: %d User ID: %s User Name: %s Addresses: %s @@ -14,7 +15,7 @@ Orders: %s Open Orders %s Closed Orderss %s ------------------ -""" % tuple([self.user_id, repr(self.user_name)] + [repr(getattr(self, attr, None)) for attr in ('addresses', 'orders', 'orders_open', 'orders_closed')]) +""" % tuple([id(self), self.user_id, repr(self.user_name)] + [repr(getattr(self, attr, None)) for attr in ('addresses', 'orders', 'orders_open', 'orders_closed')]) ) class Address(object): @@ -52,7 +53,14 @@ class MapperTest(AssertMixin): def setUp(self): globalidentity().clear() - + + def testget(self): + m = mapper(User, users, echo = True) + self.assert_(m.get(19) is None) + u = m.get(7) + u2 = m.get(7) + self.assert_(u is u2) + def testload(self): """tests loading rows with a mapper and producing object instances""" m = mapper(User, users) @@ -67,7 +75,7 @@ class MapperTest(AssertMixin): addresses = relation(Address, addresses, lazy = True) ), echo = True) l = m.options(eagerload('addresses')).select() - self.assert_result(l, User, + self.assert_result(l, User, {'user_id' : 7, 'addresses' : (Address, [{'address_id' : 1}])}, {'user_id' : 8, 'addresses' : (Address, [{'address_id' : 2}, {'address_id' : 3}])}, {'user_id' : 9, 'addresses' : (Address, [])} @@ -79,7 +87,7 @@ class MapperTest(AssertMixin): addresses = relation(Address, addresses, lazy = False) ), echo = True) l = m.options(lazyload('addresses')).select() - self.assert_result(l, User, + self.assert_result(l, User, {'user_id' : 7, 'addresses' : (Address, [{'address_id' : 1}])}, {'user_id' : 8, 'addresses' : (Address, [{'address_id' : 2}, {'address_id' : 3}])}, {'user_id' : 9, 'addresses' : (Address, [])} @@ -216,13 +224,14 @@ class EagerTest(PersistTest): print repr(l) class SaveTest(PersistTest): - def _testinsert(self): + + def testinsert(self): u = User() u.user_name = 'inserttester' - m = mapper(User, users) + m = mapper(User, users, echo=True) m.insert(u) - - nu = m.select(users.c.user_id == u.user_id) + nu = m.get(u.user_id) + # nu = m.select(users.c.user_id == u.user_id)[0] self.assert_(u is nu) if __name__ == "__main__":