From: Mike Bayer Date: Sun, 28 Aug 2005 20:44:41 +0000 (+0000) Subject: (no commit message) X-Git-Tag: rel_0_1_0~826 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=c4b7f5bd9948e55f08c09b3f855bf26257919f88;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git --- diff --git a/lib/sqlalchemy/engine.py b/lib/sqlalchemy/engine.py index 7689c53e82..9e662f455a 100644 --- a/lib/sqlalchemy/engine.py +++ b/lib/sqlalchemy/engine.py @@ -132,7 +132,7 @@ class SQLEngine(schema.SchemaEngine): if echo is True or self._echo: self.log(statement) - self.log("here are the params: " + repr(parameters)) + self.log(repr(parameters)) if connection is None: poolconn = self.connection() @@ -162,6 +162,14 @@ class ResultProxy: self.props[i] = i i+=1 + def fetchall(self): + l = [] + while True: + v = self.fetchone() + if v is None: + return l + l.append(v) + def fetchone(self): row = self.cursor.fetchone() if row is not None: @@ -174,5 +182,7 @@ class RowProxy: def __init__(self, parent, row): self.parent = parent self.row = row + def __repr__(self): + return repr(self.row) def __getitem__(self, key): return self.row[self.parent.props[key]] diff --git a/lib/sqlalchemy/mapper.py b/lib/sqlalchemy/mapper.py index c5d2e8c193..283bfb97bc 100644 --- a/lib/sqlalchemy/mapper.py +++ b/lib/sqlalchemy/mapper.py @@ -73,10 +73,17 @@ def lazyload(name): class Mapper(object): def __init__(self, class_, selectable, table = None, properties = None, identitymap = None, use_smart_properties = True, isroot = True, echo = None): self.class_ = class_ - self.selectable = selectable self.use_smart_properties = use_smart_properties + + self.selectable = selectable + tf = Mapper.TableFinder() + self.selectable.accept_visitor(tf) + self.tables = tf.tables + if table is None: - self.table = self._find_table(selectable) + if len(self.tables) > 1: + raise "Selectable contains multiple tables - specify primary table argument to Mapper" + self.table = self.tables[0] else: self.table = table @@ -141,7 +148,7 @@ class Mapper(object): except KeyError: clause = sql.and_() i = 0 - for primary_key in self.selectable.primary_keys: + for primary_key in self.table.primary_keys: # appending to the and_'s clause list directly to skip # typechecks etc. clause.clauses.append(primary_key == ident[i]) @@ -190,7 +197,7 @@ class Mapper(object): else: return self._select_whereclause(arg, **params) - def save(self, object, traverse = True, refetch = False): + def save(self, obj, traverse = True, refetch = False): """saves the object. based on the existence of its primary key, either inserts or updates. primary key is determined by the underlying database engine's sequence methodology. traverse indicates attached objects should be saved as well. @@ -199,31 +206,44 @@ class Mapper(object): of the attribute, determines if the item is saved. if smart attributes are not being used, the item is saved unconditionally. """ - if getattr(object, 'dirty', True): - pass - # do the save + # TODO: support multi-table saves + if getattr(obj, 'dirty', True): + for table in self.tables: + for col in table.columns: + if getattr(obj, col.key, None) is None: + self.insert(obj, table) + break + else: + self.update(obj, table) + for prop in self.props.values(): - prop.save(object, traverse, refetch) - - def remove(self, object, traverse = True): + prop.save(obj, traverse, refetch) + + def remove(self, obj, traverse = True): """removes the object. traverse indicates attached objects should be removed as well.""" pass - - def insert(self, obj): - """inserts the object into its table, regardless of primary key being set. this is a + + def insert(self, obj, table = None): + """inserts an object into one table, regardless of primary key being set. this is a lower-level operation than save.""" + + if table is None: + table = self.table + params = {} - for col in self.table.columns: + for col in table.columns: params[col.key] = getattr(obj, col.key, None) - ins = self.table.insert() + ins = table.insert() + ins.echo = self.echo ins.execute(**params) - # TODO: unset dirty flag + # unset dirty flag + obj.dirty = False # populate new primary keys - primary_keys = self.table.engine.last_inserted_ids() + primary_keys = table.engine.last_inserted_ids() index = 0 - for pk in self.table.primary_keys: + for pk in table.primary_keys: newid = primary_keys[index] index += 1 # TODO: do this via the ColumnProperty objects @@ -231,15 +251,24 @@ class Mapper(object): self.put(obj) - def update(self, obj): - """inserts the object into its table, regardless of primary key being set. this is a + def update(self, obj, table = None): + """updates an object in one table, regardless of primary key being set. this is a lower-level operation than save.""" + + if table is None: + table = self.table params = {} - for col in self.table.columns: - params[col.key] = getattr(obj, col.key) - upd = self.table.update() + clause = sql.and_() + for col in table.columns: + if col.primary_key: + clause.clauses.append(col == getattr(obj, col.key)) + else: + params[col.key] = getattr(obj, col.key) + upd = table.update(clause) + upd.echo = self.echo upd.execute(**params) - # TODO: unset dirty flag + # unset dirty flag + obj.dirty = False def delete(self, obj): """deletes the object's row from its table unconditionally. this is a lower-level @@ -251,15 +280,11 @@ class Mapper(object): pass class TableFinder(sql.ClauseVisitor): + def __init__(self): + self.tables = [] def visit_table(self, table): - if hasattr(self, 'table'): - raise "Mapper can only create object instances against a single-table identity - specify the 'table' argument to the Mapper constructor" - self.table = table - - def _find_table(self, selectable): - tf = Mapper.TableFinder() - selectable.accept_visitor(tf) - return tf.table + self.tables.append(table) + def _compile(self, whereclause = None, **options): statement = sql.select([self.selectable], whereclause) @@ -267,7 +292,7 @@ class Mapper(object): value.setup(key, self.selectable, statement, **options) statement.use_labels = True return statement - + def _select_whereclause(self, whereclause = None, **params): statement = self._compile(whereclause) return self._select_statement(statement, **params) @@ -280,7 +305,6 @@ class Mapper(object): def _identity_key(self, row): return self.identitymap.get_key(row, self.class_, self.table, self.selectable) - def _instance(self, row, localmap, result): """pulls an object instance from the given row and appends it to the given result list. if the instance already exists in the given identity map, its not added. in either @@ -293,7 +317,7 @@ class Mapper(object): exists = self.identitymap.has_key(identitykey) if not exists: instance = self.class_() - for column in self.selectable.primary_keys: + for column in self.table.primary_keys: if row[column.label] is None: return None self.identitymap[identitykey] = instance @@ -308,7 +332,6 @@ class Mapper(object): imap = localmap[id(result)] except KeyError: imap = localmap.setdefault(id(result), IdentityMap()) - isduplicate = imap.has_key(identitykey) if not isduplicate: imap[identitykey] = instance @@ -325,7 +348,6 @@ class MapperOption: of it. This is used to assist in the prototype pattern used by mapper.options().""" def process(self, mapper): raise NotImplementedError() - def hash_key(self): return repr(self) @@ -386,6 +408,8 @@ class ColumnProperty(MapperProperty): class PropertyLoader(MapperProperty): + """describes an object property that holds a list of items that correspond to a related + database table.""" def __init__(self, mapper, secondary, primaryjoin, secondaryjoin): self.mapper = mapper self.target = self.mapper.selectable @@ -484,6 +508,7 @@ class LazyLoadInstance(object): return self.mapper.select(self.lazywhere, **self.params) class EagerLoader(PropertyLoader): + """loads related objects inline with a parent query.""" def init(self, key, parent, root): PropertyLoader.init(self, key, parent, root) self.to_alias = util.Set() @@ -504,22 +529,22 @@ class EagerLoader(PropertyLoader): aliasizer = Aliasizer(target, "aliased_" + target.name + "_" + hex(random.randint(0, 65535))[2:]) statement.whereclause.accept_visitor(aliasizer) statement.append_from(aliasizer.alias) - + if hasattr(statement, '_outerjoin'): towrap = statement._outerjoin else: towrap = primarytable - + if self.secondaryjoin is not None: statement._outerjoin = sql.outerjoin(sql.outerjoin(towrap, self.secondary, self.secondaryjoin), self.target, self.primaryjoin) else: statement._outerjoin = sql.outerjoin(towrap, self.target, self.primaryjoin) - + statement.append_from(statement._outerjoin) statement.append_column(self.target) for key, value in self.mapper.props.iteritems(): value.setup(key, self.mapper.selectable, statement) - + def execute(self, instance, row, identitykey, localmap, isduplicate): """receive a row. tell our mapper to look for a new object instance in the row, and attach it to a list on the parent instance.""" @@ -561,15 +586,13 @@ class Aliasizer(sql.ClauseVisitor): if isinstance(binary.right, schema.Column) and binary.right.table == self.table: binary.right = self.alias.c[binary.right.name] - class LazyRow(MapperProperty): + """TODO: this will lazy-load additional properties of an object from a secondary table.""" def __init__(self, table, whereclause, **options): self.table = table self.whereclause = whereclause - def init(self, key, parent, root): self.keys.append(key) - def execute(self, instance, row, identitykey, localmap, isduplicate): pass @@ -599,9 +622,9 @@ class IdentityMap(dict): def get_id_key(self, ident, class_, table, selectable): return (class_, table, tuple(ident)) def get_instance_key(self, object, class_, table, selectable): - return (class_, table, tuple([getattr(object, column.key, None) for column in selectable.primary_keys])) + return (class_, table, tuple([getattr(object, column.key, None) for column in table.primary_keys])) def get_key(self, row, class_, table, selectable): - return (class_, table, tuple([row[column.label] for column in selectable.primary_keys])) + return (class_, table, tuple([row[column.label] for column in table.primary_keys])) def hash_key(self): return "IdentityMap(%s)" % id(self) diff --git a/test/mapper.py b/test/mapper.py index c44d8dcfc5..81dcb8e53b 100644 --- a/test/mapper.py +++ b/test/mapper.py @@ -10,12 +10,13 @@ class User(object): objid: %d User ID: %s User Name: %s +email address ?: %s Addresses: %s Orders: %s Open Orders %s Closed Orderss %s ------------------ -""" % tuple([id(self), self.user_id, repr(self.user_name)] + [repr(getattr(self, attr, None)) for attr in ('addresses', 'orders', 'orders_open', 'orders_closed')]) +""" % tuple([id(self), self.user_id, repr(self.user_name), repr(getattr(self, 'email_address', None))] + [repr(getattr(self, attr, None)) for attr in ('addresses', 'orders', 'orders_open', 'orders_closed')]) ) class Address(object): @@ -69,6 +70,12 @@ class MapperTest(AssertMixin): l = m.select(users.c.user_name.endswith('ed')) self.assert_result(l, User, {'user_id' : 8}, {'user_id' : 9}) + def testmultitable(self): + usersaddresses = sql.join(users, addresses, users.c.user_id == addresses.c.user_id) + m = mapper(User, usersaddresses, table = users) + l = m.select() + print repr(l) + def testeageroptions(self): """tests that a lazy relation can be upgraded to an eager relation via the options method""" m = mapper(User, users, properties = dict( @@ -103,7 +110,7 @@ class LazyTest(AssertMixin): addresses = relation(Address, addresses, lazy = True) ), echo = True) l = m.select(users.c.user_id == 7) - self.assert_result(l, User, + self.assert_result(l, User, {'user_id' : 7, 'addresses' : (Address, [{'address_id' : 1}])}, ) @@ -117,10 +124,10 @@ class LazyTest(AssertMixin): l = m.select() self.assert_result(l, Item, {'item_id' : 1, 'keywords' : (Keyword, [{'keyword_id' : 2}, {'keyword_id' : 4}, {'keyword_id' : 6}])}, - {'item_id' : 3, 'keywords' : (Keyword, [{'keyword_id' : 3}, {'keyword_id' : 4}, {'keyword_id' : 6}])}, {'item_id' : 2, 'keywords' : (Keyword, [{'keyword_id' : 2}, {'keyword_id' : 5}, {'keyword_id' : 7}])}, - {'item_id' : 5, 'keywords' : (Keyword, [])}, - {'item_id' : 4, 'keywords' : (Keyword, [])} + {'item_id' : 3, 'keywords' : (Keyword, [{'keyword_id' : 3}, {'keyword_id' : 4}, {'keyword_id' : 6}])}, + {'item_id' : 4, 'keywords' : (Keyword, [])}, + {'item_id' : 5, 'keywords' : (Keyword, [])} ) l = m.select(and_(keywords.c.name == 'red', keywords.c.keyword_id == itemkeywords.c.keyword_id, items.c.item_id==itemkeywords.c.item_id)) @@ -230,9 +237,61 @@ class SaveTest(PersistTest): u.user_name = 'inserttester' m = mapper(User, users, echo=True) m.insert(u) +# nu = m.get(u.user_id) + nu = m.select(users.c.user_id == u.user_id)[0] + self.assert_(u is nu) + + def testsave(self): + # save two users + u = User() + u.user_name = 'savetester' + u2 = User() + u2.user_name = 'savetester2' + m = mapper(User, users, echo=True) + m.save(u) + m.save(u2) + + # assert the first one retreives the same from the identity map nu = m.get(u.user_id) - # nu = m.select(users.c.user_id == u.user_id)[0] self.assert_(u is nu) + + # clear out the identity map, so next get forces a SELECT + m.identitymap.clear() + + # check it again, identity should be different but ids the same + nu = m.get(u.user_id) + self.assert_(u is not nu and u.user_id == nu.user_id and nu.user_name == 'savetester') + + # change first users name and save + u.user_name = 'modifiedname' + m.save(u) + # select both + userlist = m.select(users.c.user_id.in_(u.user_id, u2.user_id)) + # making a slight assumption here about the IN clause mechanics with regards to ordering + self.assert_(u.user_id == userlist[0].user_id and userlist[0].user_name == 'modifiedname') + self.assert_(u2.user_id == userlist[1].user_id and userlist[1].user_name == 'savetester2') + + def testsavemultitable(self): + usersaddresses = sql.join(users, addresses, users.c.user_id == addresses.c.user_id) + m = mapper(User, usersaddresses, table = users) + u = User() + u.user_name = 'multitester' + u.email_address = 'multi@test.org' + m.save(u) + + usertable = engine.ResultProxy(users.select().execute()).fetchall() + print repr(usertable) + addresstable = engine.ResultProxy(addresses.select().execute()).fetchall() + print repr(addresstable) + + u.email_address = 'lala@hey.com' + u.user_name = 'imnew' + m.save(u) + usertable = engine.ResultProxy(users.select().execute()).fetchall() + print repr(usertable) + addresstable = engine.ResultProxy(addresses.select().execute()).fetchall() + print repr(addresstable) + if __name__ == "__main__": unittest.main() diff --git a/test/query.py b/test/query.py index af01d3191f..c92ae70e9e 100644 --- a/test/query.py +++ b/test/query.py @@ -3,7 +3,7 @@ import unittest, sys import sqlalchemy.databases.sqlite as sqllite -db = sqllite.engine('querytest.db', echo = True) +db = sqllite.engine(':memory:', {}, echo = True) from sqlalchemy.sql import * from sqlalchemy.schema import *