from exceptions import *
import mapping as mapperlib
from mapping import *
-
import sqlalchemy.schema
import sqlalchemy.ext.proxy
sqlalchemy.schema.default_engine = sqlalchemy.ext.proxy.ProxyEngine()
+from sqlalchemy.mods import install_mods
+
def global_connect(*args, **kwargs):
sqlalchemy.schema.default_engine.connect(*args, **kwargs)
\ No newline at end of file
# a dictionary mapping classes to their primary mappers
mapper_registry = weakref.WeakKeyDictionary()
+# a list of MapperExtensions that will be installed by default
+extensions = []
+
# a constant returned by _getattrbycolumn to indicate
# this mapper is not handling an attribute for a particular
# column
NO_ATTRIBUTE = object()
+# returned by a MapperExtension method to indicate a "do nothing" response
+EXT_PASS = object()
+
class Mapper(object):
"""Persists object instances to and from schema.Table objects via the sql package.
Instances of this class should be constructed through this package's mapper() or
if primarytable is not None:
sys.stderr.write("'primarytable' argument to mapper is deprecated\n")
+
+ ext = MapperExtension()
+
+ for ext_class in extensions:
+ ext = ext_class().chain(ext)
- if extension is None:
- self.extension = MapperExtension()
+ if extension is not None:
+ self.extension = extension.chain(ext)
else:
- self.extension = extension
+ self.extension = ext
+
self.class_ = class_
self.is_primary = is_primary
self.order_by = order_by
e.g. result = usermapper.select_by(user_name = 'fred')
"""
- return mapperutil.SelectResults(self, self._by_clause(*args, **params))
+ ret = self.extension.select_by(self, *args, **params)
+ if ret is not EXT_PASS:
+ return ret
+ return self.select_whereclause(self._by_clause(*args, **params))
def selectfirst_by(self, *args, **params):
"""works like select_by(), but only returns the first result by itself, or None if no
def selectone_by(self, *args, **params):
"""works like selectfirst_by(), but throws an error if not exactly one result was returned."""
- ret = list(self.select_by(*args, **params)[0:2])
+ ret = mapper.select_whereclause(self._by_clause(*args, **params), limit=2)
if len(ret) == 1:
return ret[0]
raise InvalidRequestError('Multiple rows returned for selectone_by')
return ret[0]
raise InvalidRequestError('Multiple rows returned for selectone')
- def select(self, arg = None, **kwargs):
+ def select(self, arg=None, **kwargs):
"""selects instances of the object from the database.
arg can be any ClauseElement, which will form the criterion with which to
will be executed and its resulting rowset used to build new object instances.
in this case, the developer must insure that an adequate set of columns exists in the
rowset with which to build new object instances."""
- if arg is not None and isinstance(arg, sql.Selectable):
+
+ ret = self.extension.select(self, arg=arg, **kwargs)
+ if ret is not EXT_PASS:
+ return ret
+ elif arg is not None and isinstance(arg, sql.Selectable):
return self.select_statement(arg, **kwargs)
else:
- return mapperutil.SelectResults(self, arg, ops=kwargs)
+ return self.select_whereclause(whereclause=arg, **kwargs)
def select_whereclause(self, whereclause=None, params=None, **kwargs):
statement = self._compile(whereclause, **kwargs)
imap[identitykey] = instance
for prop in self.props.values():
prop.execute(instance, row, identitykey, imap, True)
- if self.extension.append_result(self, row, imap, result, instance, isnew, populate_existing=populate_existing):
+ if self.extension.append_result(self, row, imap, result, instance, isnew, populate_existing=populate_existing) is EXT_PASS:
if result is not None:
result.append_nohistory(instance)
return instance
return None
# plugin point
instance = self.extension.create_instance(self, row, imap, self.class_)
- if instance is None:
+ if instance is EXT_PASS:
instance = self.class_(_mapper_nohistory=True)
imap[identitykey] = instance
isnew = True
# call further mapper properties on the row, to pull further
# instances from the row and possibly populate this item.
- if self.extension.populate_instance(self, instance, row, identitykey, imap, isnew):
+ if self.extension.populate_instance(self, instance, row, identitykey, imap, isnew) is EXT_PASS:
self.populate_instance(instance, row, identitykey, imap, isnew)
- if self.extension.append_result(self, row, imap, result, instance, isnew, populate_existing=populate_existing):
+ if self.extension.append_result(self, row, imap, result, instance, isnew, populate_existing=populate_existing) is EXT_PASS:
if result is not None:
result.append_nohistory(instance)
return instance
class MapperExtension(object):
def __init__(self):
self.next = None
+ def chain(self, ext):
+ self.next = ext
+ return self
+ def select_by(self, mapper, *args, **kwargs):
+ if self.next is None:
+ return EXT_PASS
+ else:
+ return self.next.select_by(mapper, *args, **kwargs)
+ def select(self, mapper, *args, **kwargs):
+ if self.next is None:
+ return EXT_PASS
+ else:
+ return self.next.select(mapper, *args, **kwargs)
def create_instance(self, mapper, row, imap, class_):
"""called when a new object instance is about to be created from a row.
the method can choose to create the instance itself, or it can return
class_ - the class we are mapping.
"""
if self.next is None:
- return None
+ return EXT_PASS
else:
return self.next.create_instance(mapper, row, imap, class_)
def append_result(self, mapper, row, imap, result, instance, isnew, populate_existing=False):
identity map, i.e. were loaded by a previous select(), get their attributes overwritten
"""
if self.next is None:
- return True
+ return EXT_PASS
else:
return self.next.append_result(mapper, row, imap, result, instance, isnew, populate_existing)
def populate_instance(self, mapper, instance, row, identitykey, imap, isnew):
def populate_instance(self, mapper, instance, row, identitykey, imap, isnew):
othermapper.populate_instance(instance, row, identitykey, imap, isnew, frommapper=mapper)
- return False
+ return True
"""
if self.next is None:
- return True
+ return EXT_PASS
else:
return self.next.populate_instance(row, imap, result, instance, isnew)
def before_insert(self, mapper, instance):
else:
uowcommit.register_dependency(self.mapper, self.parent)
uowcommit.register_processor(self.mapper, self, self.parent, False)
- # this dependency processor is used to locate "private" child objects
- # during a "delete" operation, when the objectstore is being committed
- # with only a partial list of objects
uowcommit.register_processor(self.mapper, self, self.parent, True)
else:
raise AssertionError(" no foreign key ?")
order_by = self.secondary.default_order_by()
else:
order_by = False
- result = list(self.mapper.select(self.lazywhere, order_by=order_by, params=params))
+ result = self.mapper.select_whereclause(self.lazywhere, order_by=order_by, params=params)
else:
result = []
if self.uselist:
import sqlalchemy.sql as sql
-class SelectResults(object):
- def __init__(self, mapper, clause=None, ops={}):
- self._mapper = mapper
- self._clause = clause
- self._ops = {}
- self._ops.update(ops)
-
- def count(self):
- return self._mapper.count(self._clause)
-
- def min(self, col):
- return sql.select([sql.func.min(col)], self._clause, **self._ops).scalar()
-
- def max(self, col):
- return sql.select([sql.func.max(col)], self._clause, **self._ops).scalar()
-
- def sum(self, col):
- return sql.select([sql.func.sum(col)], self._clause, **self._ops).scalar()
-
- def avg(self, col):
- return sql.select([sql.func.avg(col)], self._clause, **self._ops).scalar()
-
- def clone(self):
- return SelectResults(self._mapper, self._clause, self._ops.copy())
-
- def filter(self, clause):
- new = self.clone()
- new._clause = sql.and_(self._clause, clause)
- return new
-
- def order_by(self, order_by):
- new = self.clone()
- new._ops['order_by'] = order_by
- return new
-
- def limit(self, limit):
- return self[:limit]
-
- def offset(self, offset):
- return self[offset:]
-
- def list(self):
- return list(self)
-
- def __getitem__(self, item):
- if isinstance(item, slice):
- start = item.start
- stop = item.stop
- if (isinstance(start, int) and start < 0) or \
- (isinstance(stop, int) and stop < 0):
- return list(self)[item]
- else:
- res = self.clone()
- if start is not None and stop is not None:
- res._ops.update(dict(offset=start, limit=stop-start))
- elif start is None and stop is not None:
- res._ops.update(dict(limit=stop))
- elif start is not None and stop is None:
- res._ops.update(dict(offset=start))
- if item.step is not None:
- return list(res)[None:None:item.step]
- else:
- return res
- else:
- return list(self[item:item+1])[0]
-
- def __iter__(self):
- return iter(self._mapper.select_whereclause(self._clause, **self._ops))
-
-
class TableFinder(sql.ClauseVisitor):
"""given a Clause, locates all the Tables within it into a list."""
def __init__(self, table, check_columns=False):
--- /dev/null
+def install_mods(*mods):
+ for mod in mods:
+ mod.install_plugin()
\ No newline at end of file
--- /dev/null
+import sqlalchemy.sql as sql
+
+import sqlalchemy.mapping as mapping
+
+def install_plugin():
+ mapping.extensions.append(SelectResultsExt)
+
+class SelectResultsExt(mapping.MapperExtension):
+ def select_by(self, mapper, *args, **params):
+ return SelectResults(mapper, mapper._by_clause(*args, **params))
+ def select(self, mapper, arg=None, **kwargs):
+ if arg is not None and isinstance(arg, sql.Selectable):
+ return mapping.EXT_PASS
+ else:
+ return SelectResults(mapper, arg, ops=kwargs)
+
+class SelectResults(object):
+ def __init__(self, mapper, clause=None, ops={}):
+ self._mapper = mapper
+ self._clause = clause
+ self._ops = {}
+ self._ops.update(ops)
+
+ def count(self):
+ return self._mapper.count(self._clause)
+
+ def min(self, col):
+ return sql.select([sql.func.min(col)], self._clause, **self._ops).scalar()
+
+ def max(self, col):
+ return sql.select([sql.func.max(col)], self._clause, **self._ops).scalar()
+
+ def sum(self, col):
+ return sql.select([sql.func.sum(col)], self._clause, **self._ops).scalar()
+
+ def avg(self, col):
+ return sql.select([sql.func.avg(col)], self._clause, **self._ops).scalar()
+
+ def clone(self):
+ return SelectResults(self._mapper, self._clause, self._ops.copy())
+
+ def filter(self, clause):
+ new = self.clone()
+ new._clause = sql.and_(self._clause, clause)
+ return new
+
+ def order_by(self, order_by):
+ new = self.clone()
+ new._ops['order_by'] = order_by
+ return new
+
+ def limit(self, limit):
+ return self[:limit]
+
+ def offset(self, offset):
+ return self[offset:]
+
+ def list(self):
+ return list(self)
+
+ def __getitem__(self, item):
+ if isinstance(item, slice):
+ start = item.start
+ stop = item.stop
+ if (isinstance(start, int) and start < 0) or \
+ (isinstance(stop, int) and stop < 0):
+ return list(self)[item]
+ else:
+ res = self.clone()
+ if start is not None and stop is not None:
+ res._ops.update(dict(offset=start, limit=stop-start))
+ elif start is None and stop is not None:
+ res._ops.update(dict(limit=stop))
+ elif start is not None and stop is None:
+ res._ops.update(dict(offset=start))
+ if item.step is not None:
+ return list(res)[None:None:item.step]
+ else:
+ return res
+ else:
+ return list(self[item:item+1])[0]
+
+ def __iter__(self):
+ return iter(self._mapper.select_whereclause(self._clause, **self._ops))
# object isnt refreshed yet, using dict to bypass trigger
self.assert_(u.__dict__['user_name'] != 'jack')
# do a select
- m.select().list()
+ m.select()
# test that it refreshed
self.assert_(u.__dict__['user_name'] == 'jack')
m = mapper(User, users, properties = dict(
addresses = relation(mapper(Address, addresses), lazy = True)
))
- l = m.options(eagerload('addresses')).select().list()
+ l = m.options(eagerload('addresses')).select()
def go():
self.assert_result(l, User, *user_address_result)
m = mapper(User, users, properties = dict(
addresses = relation(mapper(Address, addresses), lazy = False)
))
- l = m.options(lazyload('addresses')).select().list()
+ l = m.options(lazyload('addresses')).select()
def go():
self.assert_result(l, User, *user_address_result)
self.assert_sql_count(db, go, 3)
})
m2 = m.options(eagerload('orders.items.keywords'))
- u = m.select().list()
+ u = m.select()
def go():
print u[0].orders[1].items[0].keywords[1]
self.assert_sql_count(db, go, 3)
objectstore.clear()
- u = m2.select().list()
+ u = m2.select()
self.assert_sql_count(db, go, 2)
class PropertyTest(MapperSuperTest):
self.assert_(o.description is None)
def go():
- l = m.select().list()
+ l = m.select()
o2 = l[2]
print o2.description
})
def go():
- l = m.select().list()
+ l = m.select()
o2 = l[2]
print o2.opened, o2.description, o2.userident
self.assert_sql(db, go, [
m = mapper(Order, orders)
m2 = m.options(defer('user_id'))
def go():
- l = m2.select().list()
+ l = m2.select()
print l[2].user_id
self.assert_sql(db, go, [
("SELECT orders.order_id AS orders_order_id, orders.description AS orders_description, orders.isopen AS orders_isopen FROM orders ORDER BY orders.%s" % orders.default_order_by()[0].key, {}),
objectstore.clear()
m3 = m2.options(undefer('user_id'))
def go():
- l = m3.select().list()
+ l = m3.select()
print l[3].user_id
self.assert_sql(db, go, [
("SELECT orders.order_id AS orders_order_id, orders.user_id AS orders_user_id, orders.description AS orders_description, orders.isopen AS orders_isopen FROM orders ORDER BY orders.%s" % orders.default_order_by()[0].key, {}),
try:
trans = objectstore.begin()
- all = User.select()[:].list()
+ all = User.select()[:]
assert all == []
u = User()