From 370cdcf8eec0eb9297ca9d7d79f01b397419e2a2 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Sun, 17 Jul 2005 00:45:36 +0000 Subject: [PATCH] --- lib/sqlalchemy/mapper.py | 57 ++++++++++++++++++++++++---------------- lib/sqlalchemy/sql.py | 1 - test/mapper.py | 24 +++++++++++++---- 3 files changed, 53 insertions(+), 29 deletions(-) diff --git a/lib/sqlalchemy/mapper.py b/lib/sqlalchemy/mapper.py index 36f1317453..99bd7d455e 100644 --- a/lib/sqlalchemy/mapper.py +++ b/lib/sqlalchemy/mapper.py @@ -59,18 +59,19 @@ class Mapper(object): def instances(self, cursor): result = [] cursor = ResultProxy(cursor) - lastinstance = None + localmap = IdentityMap() while True: row = cursor.fetchone() if row is None: break - instance = self._instance(row) - if instance != lastinstance: + + identitykey = localmap.get_key(row, self.class_, self.table) + if not localmap.map.has_key(identitykey): + instance = self._create(row, identitykey, localmap) result.append(instance) - lastinstance = instance else: for key, prop in self.props.iteritems(): - prop.execute(instance, key, row, True) + prop.execute(instance, key, row, identitykey, localmap, True) return result @@ -110,13 +111,18 @@ class Mapper(object): statement.use_labels = True return self.instances(statement.execute(**params)) - def _instance(self, row): - return self.identitymap.get(row, self.class_, self.table, creator = self._create) + def _identity_key(self, row): + return self.identitymap.get_key(row, self.class_, self.table) - def _create(self, row): + def _create(self, row, identitykey, localmap): instance = self.class_() + for column in self.table.primary_keys: + if row[column.label] is None: + return None for key, prop in self.props.iteritems(): - prop.execute(instance, key, row, False) + prop.execute(instance, key, row, identitykey, localmap, False) + self.identitymap.map[identitykey] = instance + localmap.map[identitykey] = instance return instance @@ -130,7 +136,7 @@ class ColumnProperty(MapperProperty): def __init__(self, column): self.column = column - def execute(self, instance, key, row, isduplicate): + def execute(self, instance, key, row, identitykey, localmap, isduplicate): if not isduplicate: setattr(instance, key, row[self.column.label]) @@ -138,6 +144,7 @@ class EagerLoader(MapperProperty): def __init__(self, mapper, whereclause): self.mapper = mapper self.whereclause = whereclause + def setup(self, key, primarytable, statement): targettable = self.mapper.table if hasattr(statement, '_outerjoin'): @@ -146,14 +153,17 @@ class EagerLoader(MapperProperty): statement._outerjoin = sql.outerjoin(primarytable, targettable, self.whereclause) statement.append_from(statement._outerjoin) statement.append_column(targettable) - def execute(self, instance, key, row, isduplicate): + + def execute(self, instance, key, row, identitykey, localmap, isduplicate): try: list = getattr(instance, key) except AttributeError: list = [] setattr(instance, key, list) - subinstance = self.mapper._instance(row) - if subinstance is not None: + + identitykey = self.mapper._identity_key(row) + if not localmap.has_key(identitykey): + subinstance = self.mapper._create(row, identitykey, localmap) list.append(subinstance) class ResultProxy: @@ -186,22 +196,23 @@ class IdentityMap(object): def __init__(self): self.map = {} self.keystereotypes = {} + + def has_key(self, key): + return self.map.has_key(key) + + def get_key(self, row, class_, table): + return (class_, table.id, tuple([row[column.label] for column in table.primary_keys])) - def get(self, row, class_, table, creator = None): + def get(self, row, class_, table, key = None): """given a database row, a class to be instantiated, and a table corresponding to the row, returns a corrseponding object instance, if any, from the identity map. the primary keys specified in the table will be used to indicate which columns from the row form the effective key of the instance.""" - key = (class_, table, tuple([row[column.label] for column in table.primary_keys])) - try: - return self.map[key] - except KeyError: - newinstance = creator(row) - for column in table.primary_keys: - if row[column.label] is None: - return None - return self.map.setdefault(key, newinstance) + if key is None: + key = self.get_key(row, class_, table) + + return self.map[key] diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py index 071a6185ce..7438408d69 100644 --- a/lib/sqlalchemy/sql.py +++ b/lib/sqlalchemy/sql.py @@ -296,7 +296,6 @@ class Join(Selectable): def _get_from_objects(self): result = [self] + [FromClause(from_key = c.id) for c in self.left._get_from_objects() + self.right._get_from_objects()] - print repr([c.id for c in result]) return result class Alias(Selectable): diff --git a/test/mapper.py b/test/mapper.py index 23d01baac4..3d0ee0b8ac 100644 --- a/test/mapper.py +++ b/test/mapper.py @@ -12,9 +12,17 @@ import sqlalchemy.mapper as mapper class User: def __repr__(self): - return ("User: " + repr(self.user_id) + " " + self.user_name + repr(getattr(self, 'addresses', None)) + - repr(getattr(self, 'orders', None)) - ) + return ( +""" +User ID: %s +Addresses: %s +Orders: %s +Open Orders %s +Closed Orders %s +------------------ +""" % tuple([self.user_id] + [repr(getattr(self, attr, None)) for attr in ('addresses', 'orders', 'orders_open', 'orders_closed')]) +) + class Address: def __repr__(self): @@ -67,7 +75,6 @@ class MapperTest(PersistTest): m = mapper.Mapper(User, self.users) l = m.select() print repr(l) - print repr(m.identitymap.map) def testeager(self): m = mapper.Mapper(User, self.users, properties = dict( @@ -77,12 +84,19 @@ class MapperTest(PersistTest): print repr(l) def testmultieager(self): + m = mapper.Mapper(User, self.users, properties = dict( + addresses = mapper.EagerLoader(mapper.Mapper(Address, self.addresses), self.users.c.user_id==self.addresses.c.user_id), + orders = mapper.EagerLoader(mapper.Mapper(Order, self.orders), and_(self.orders.c.isopen == 1, self.users.c.user_id==self.orders.c.user_id)), + ), identitymap = mapper.IdentityMap()) + l = m.select() + print repr(l) +# return openorders = alias(self.orders, 'openorders') closedorders = alias(self.orders, 'closedorders') m = mapper.Mapper(User, self.users, properties = dict( orders_open = mapper.EagerLoader(mapper.Mapper(Order, openorders), and_(openorders.c.isopen == 1, self.users.c.user_id==openorders.c.user_id)), orders_closed = mapper.EagerLoader(mapper.Mapper(Order, closedorders), and_(closedorders.c.isopen == 0, self.users.c.user_id==closedorders.c.user_id)) - )) + ), identitymap = mapper.IdentityMap()) l = m.select() print repr(l) -- 2.47.2