class Mapper(object):
def __init__(self, class_, selectable, table = None, properties = None, identitymap = None, use_smart_properties = True, isroot = True, echo = None):
self.class_ = class_
- self.selectable = selectable
self.use_smart_properties = use_smart_properties
+
+ self.selectable = selectable
+ tf = Mapper.TableFinder()
+ self.selectable.accept_visitor(tf)
+ self.tables = tf.tables
+
if table is None:
- self.table = self._find_table(selectable)
+ if len(self.tables) > 1:
+ raise "Selectable contains multiple tables - specify primary table argument to Mapper"
+ self.table = self.tables[0]
else:
self.table = table
except KeyError:
clause = sql.and_()
i = 0
- for primary_key in self.selectable.primary_keys:
+ for primary_key in self.table.primary_keys:
# appending to the and_'s clause list directly to skip
# typechecks etc.
clause.clauses.append(primary_key == ident[i])
else:
return self._select_whereclause(arg, **params)
- def save(self, object, traverse = True, refetch = False):
+ def save(self, obj, traverse = True, refetch = False):
"""saves the object. based on the existence of its primary key, either inserts or updates.
primary key is determined by the underlying database engine's sequence methodology.
traverse indicates attached objects should be saved as well.
of the attribute, determines if the item is saved. if smart attributes are not being
used, the item is saved unconditionally.
"""
- if getattr(object, 'dirty', True):
- pass
- # do the save
+ # TODO: support multi-table saves
+ if getattr(obj, 'dirty', True):
+ for table in self.tables:
+ for col in table.columns:
+ if getattr(obj, col.key, None) is None:
+ self.insert(obj, table)
+ break
+ else:
+ self.update(obj, table)
+
for prop in self.props.values():
- prop.save(object, traverse, refetch)
-
- def remove(self, object, traverse = True):
+ prop.save(obj, traverse, refetch)
+
+ def remove(self, obj, traverse = True):
"""removes the object. traverse indicates attached objects should be removed as well."""
pass
-
- def insert(self, obj):
- """inserts the object into its table, regardless of primary key being set. this is a
+
+ def insert(self, obj, table = None):
+ """inserts an object into one table, regardless of primary key being set. this is a
lower-level operation than save."""
+
+ if table is None:
+ table = self.table
+
params = {}
- for col in self.table.columns:
+ for col in table.columns:
params[col.key] = getattr(obj, col.key, None)
- ins = self.table.insert()
+ ins = table.insert()
+ ins.echo = self.echo
ins.execute(**params)
- # TODO: unset dirty flag
+ # unset dirty flag
+ obj.dirty = False
# populate new primary keys
- primary_keys = self.table.engine.last_inserted_ids()
+ primary_keys = table.engine.last_inserted_ids()
index = 0
- for pk in self.table.primary_keys:
+ for pk in table.primary_keys:
newid = primary_keys[index]
index += 1
# TODO: do this via the ColumnProperty objects
self.put(obj)
- def update(self, obj):
- """inserts the object into its table, regardless of primary key being set. this is a
+ def update(self, obj, table = None):
+ """updates an object in one table, regardless of primary key being set. this is a
lower-level operation than save."""
+
+ if table is None:
+ table = self.table
params = {}
- for col in self.table.columns:
- params[col.key] = getattr(obj, col.key)
- upd = self.table.update()
+ clause = sql.and_()
+ for col in table.columns:
+ if col.primary_key:
+ clause.clauses.append(col == getattr(obj, col.key))
+ else:
+ params[col.key] = getattr(obj, col.key)
+ upd = table.update(clause)
+ upd.echo = self.echo
upd.execute(**params)
- # TODO: unset dirty flag
+ # unset dirty flag
+ obj.dirty = False
def delete(self, obj):
"""deletes the object's row from its table unconditionally. this is a lower-level
pass
class TableFinder(sql.ClauseVisitor):
+ def __init__(self):
+ self.tables = []
def visit_table(self, table):
- if hasattr(self, 'table'):
- raise "Mapper can only create object instances against a single-table identity - specify the 'table' argument to the Mapper constructor"
- self.table = table
-
- def _find_table(self, selectable):
- tf = Mapper.TableFinder()
- selectable.accept_visitor(tf)
- return tf.table
+ self.tables.append(table)
+
def _compile(self, whereclause = None, **options):
statement = sql.select([self.selectable], whereclause)
value.setup(key, self.selectable, statement, **options)
statement.use_labels = True
return statement
-
+
def _select_whereclause(self, whereclause = None, **params):
statement = self._compile(whereclause)
return self._select_statement(statement, **params)
def _identity_key(self, row):
return self.identitymap.get_key(row, self.class_, self.table, self.selectable)
-
def _instance(self, row, localmap, result):
"""pulls an object instance from the given row and appends it to the given result list.
if the instance already exists in the given identity map, its not added. in either
exists = self.identitymap.has_key(identitykey)
if not exists:
instance = self.class_()
- for column in self.selectable.primary_keys:
+ for column in self.table.primary_keys:
if row[column.label] is None:
return None
self.identitymap[identitykey] = instance
imap = localmap[id(result)]
except KeyError:
imap = localmap.setdefault(id(result), IdentityMap())
-
isduplicate = imap.has_key(identitykey)
if not isduplicate:
imap[identitykey] = instance
of it. This is used to assist in the prototype pattern used by mapper.options()."""
def process(self, mapper):
raise NotImplementedError()
-
def hash_key(self):
return repr(self)
class PropertyLoader(MapperProperty):
+ """describes an object property that holds a list of items that correspond to a related
+ database table."""
def __init__(self, mapper, secondary, primaryjoin, secondaryjoin):
self.mapper = mapper
self.target = self.mapper.selectable
return self.mapper.select(self.lazywhere, **self.params)
class EagerLoader(PropertyLoader):
+ """loads related objects inline with a parent query."""
def init(self, key, parent, root):
PropertyLoader.init(self, key, parent, root)
self.to_alias = util.Set()
aliasizer = Aliasizer(target, "aliased_" + target.name + "_" + hex(random.randint(0, 65535))[2:])
statement.whereclause.accept_visitor(aliasizer)
statement.append_from(aliasizer.alias)
-
+
if hasattr(statement, '_outerjoin'):
towrap = statement._outerjoin
else:
towrap = primarytable
-
+
if self.secondaryjoin is not None:
statement._outerjoin = sql.outerjoin(sql.outerjoin(towrap, self.secondary, self.secondaryjoin), self.target, self.primaryjoin)
else:
statement._outerjoin = sql.outerjoin(towrap, self.target, self.primaryjoin)
-
+
statement.append_from(statement._outerjoin)
statement.append_column(self.target)
for key, value in self.mapper.props.iteritems():
value.setup(key, self.mapper.selectable, statement)
-
+
def execute(self, instance, row, identitykey, localmap, isduplicate):
"""receive a row. tell our mapper to look for a new object instance in the row, and attach
it to a list on the parent instance."""
if isinstance(binary.right, schema.Column) and binary.right.table == self.table:
binary.right = self.alias.c[binary.right.name]
-
class LazyRow(MapperProperty):
+ """TODO: this will lazy-load additional properties of an object from a secondary table."""
def __init__(self, table, whereclause, **options):
self.table = table
self.whereclause = whereclause
-
def init(self, key, parent, root):
self.keys.append(key)
-
def execute(self, instance, row, identitykey, localmap, isduplicate):
pass
def get_id_key(self, ident, class_, table, selectable):
return (class_, table, tuple(ident))
def get_instance_key(self, object, class_, table, selectable):
- return (class_, table, tuple([getattr(object, column.key, None) for column in selectable.primary_keys]))
+ return (class_, table, tuple([getattr(object, column.key, None) for column in table.primary_keys]))
def get_key(self, row, class_, table, selectable):
- return (class_, table, tuple([row[column.label] for column in selectable.primary_keys]))
+ return (class_, table, tuple([row[column.label] for column in table.primary_keys]))
def hash_key(self):
return "IdentityMap(%s)" % id(self)
objid: %d
User ID: %s
User Name: %s
+email address ?: %s
Addresses: %s
Orders: %s
Open Orders %s
Closed Orderss %s
------------------
-""" % tuple([id(self), self.user_id, repr(self.user_name)] + [repr(getattr(self, attr, None)) for attr in ('addresses', 'orders', 'orders_open', 'orders_closed')])
+""" % tuple([id(self), self.user_id, repr(self.user_name), repr(getattr(self, 'email_address', None))] + [repr(getattr(self, attr, None)) for attr in ('addresses', 'orders', 'orders_open', 'orders_closed')])
)
class Address(object):
l = m.select(users.c.user_name.endswith('ed'))
self.assert_result(l, User, {'user_id' : 8}, {'user_id' : 9})
+ def testmultitable(self):
+ usersaddresses = sql.join(users, addresses, users.c.user_id == addresses.c.user_id)
+ m = mapper(User, usersaddresses, table = users)
+ l = m.select()
+ print repr(l)
+
def testeageroptions(self):
"""tests that a lazy relation can be upgraded to an eager relation via the options method"""
m = mapper(User, users, properties = dict(
addresses = relation(Address, addresses, lazy = True)
), echo = True)
l = m.select(users.c.user_id == 7)
- self.assert_result(l, User,
+ self.assert_result(l, User,
{'user_id' : 7, 'addresses' : (Address, [{'address_id' : 1}])},
)
l = m.select()
self.assert_result(l, Item,
{'item_id' : 1, 'keywords' : (Keyword, [{'keyword_id' : 2}, {'keyword_id' : 4}, {'keyword_id' : 6}])},
- {'item_id' : 3, 'keywords' : (Keyword, [{'keyword_id' : 3}, {'keyword_id' : 4}, {'keyword_id' : 6}])},
{'item_id' : 2, 'keywords' : (Keyword, [{'keyword_id' : 2}, {'keyword_id' : 5}, {'keyword_id' : 7}])},
- {'item_id' : 5, 'keywords' : (Keyword, [])},
- {'item_id' : 4, 'keywords' : (Keyword, [])}
+ {'item_id' : 3, 'keywords' : (Keyword, [{'keyword_id' : 3}, {'keyword_id' : 4}, {'keyword_id' : 6}])},
+ {'item_id' : 4, 'keywords' : (Keyword, [])},
+ {'item_id' : 5, 'keywords' : (Keyword, [])}
)
l = m.select(and_(keywords.c.name == 'red', keywords.c.keyword_id == itemkeywords.c.keyword_id, items.c.item_id==itemkeywords.c.item_id))
u.user_name = 'inserttester'
m = mapper(User, users, echo=True)
m.insert(u)
+# nu = m.get(u.user_id)
+ nu = m.select(users.c.user_id == u.user_id)[0]
+ self.assert_(u is nu)
+
+ def testsave(self):
+ # save two users
+ u = User()
+ u.user_name = 'savetester'
+ u2 = User()
+ u2.user_name = 'savetester2'
+ m = mapper(User, users, echo=True)
+ m.save(u)
+ m.save(u2)
+
+ # assert the first one retreives the same from the identity map
nu = m.get(u.user_id)
- # nu = m.select(users.c.user_id == u.user_id)[0]
self.assert_(u is nu)
+
+ # clear out the identity map, so next get forces a SELECT
+ m.identitymap.clear()
+
+ # check it again, identity should be different but ids the same
+ nu = m.get(u.user_id)
+ self.assert_(u is not nu and u.user_id == nu.user_id and nu.user_name == 'savetester')
+
+ # change first users name and save
+ u.user_name = 'modifiedname'
+ m.save(u)
+ # select both
+ userlist = m.select(users.c.user_id.in_(u.user_id, u2.user_id))
+ # making a slight assumption here about the IN clause mechanics with regards to ordering
+ self.assert_(u.user_id == userlist[0].user_id and userlist[0].user_name == 'modifiedname')
+ self.assert_(u2.user_id == userlist[1].user_id and userlist[1].user_name == 'savetester2')
+
+ def testsavemultitable(self):
+ usersaddresses = sql.join(users, addresses, users.c.user_id == addresses.c.user_id)
+ m = mapper(User, usersaddresses, table = users)
+ u = User()
+ u.user_name = 'multitester'
+ u.email_address = 'multi@test.org'
+ m.save(u)
+
+ usertable = engine.ResultProxy(users.select().execute()).fetchall()
+ print repr(usertable)
+ addresstable = engine.ResultProxy(addresses.select().execute()).fetchall()
+ print repr(addresstable)
+
+ u.email_address = 'lala@hey.com'
+ u.user_name = 'imnew'
+ m.save(u)
+ usertable = engine.ResultProxy(users.select().execute()).fetchall()
+ print repr(usertable)
+ addresstable = engine.ResultProxy(addresses.select().execute()).fetchall()
+ print repr(addresstable)
+
if __name__ == "__main__":
unittest.main()