From d929ae8d90174a270ee708279ff9c831f9d3193e Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Fri, 31 Mar 2006 07:20:13 +0000 Subject: [PATCH] introducing...the mods package ! the SelectResults thing moves as the first mod --- lib/sqlalchemy/__init__.py | 3 +- lib/sqlalchemy/mapping/mapper.py | 64 +++++++++++++++------ lib/sqlalchemy/mapping/properties.py | 5 +- lib/sqlalchemy/mapping/util.py | 70 ----------------------- lib/sqlalchemy/mods/__init__.py | 3 + lib/sqlalchemy/mods/selectresults.py | 84 ++++++++++++++++++++++++++++ test/mapper.py | 18 +++--- test/proxy_engine.py | 2 +- 8 files changed, 148 insertions(+), 101 deletions(-) create mode 100644 lib/sqlalchemy/mods/__init__.py create mode 100644 lib/sqlalchemy/mods/selectresults.py diff --git a/lib/sqlalchemy/__init__.py b/lib/sqlalchemy/__init__.py index 9f7b99feab..bbb57955ba 100644 --- a/lib/sqlalchemy/__init__.py +++ b/lib/sqlalchemy/__init__.py @@ -11,11 +11,12 @@ from schema import * from exceptions import * import mapping as mapperlib from mapping import * - import sqlalchemy.schema import sqlalchemy.ext.proxy sqlalchemy.schema.default_engine = sqlalchemy.ext.proxy.ProxyEngine() +from sqlalchemy.mods import install_mods + def global_connect(*args, **kwargs): sqlalchemy.schema.default_engine.connect(*args, **kwargs) \ No newline at end of file diff --git a/lib/sqlalchemy/mapping/mapper.py b/lib/sqlalchemy/mapping/mapper.py index 7e12459c53..9d2286eb05 100644 --- a/lib/sqlalchemy/mapping/mapper.py +++ b/lib/sqlalchemy/mapping/mapper.py @@ -19,11 +19,17 @@ import weakref # a dictionary mapping classes to their primary mappers mapper_registry = weakref.WeakKeyDictionary() +# a list of MapperExtensions that will be installed by default +extensions = [] + # a constant returned by _getattrbycolumn to indicate # this mapper is not handling an attribute for a particular # column NO_ATTRIBUTE = object() +# returned by a MapperExtension method to indicate a "do nothing" response +EXT_PASS = object() + class Mapper(object): """Persists object instances to and from schema.Table objects via the sql package. Instances of this class should be constructed through this package's mapper() or @@ -47,11 +53,17 @@ class Mapper(object): if primarytable is not None: sys.stderr.write("'primarytable' argument to mapper is deprecated\n") + + ext = MapperExtension() + + for ext_class in extensions: + ext = ext_class().chain(ext) - if extension is None: - self.extension = MapperExtension() + if extension is not None: + self.extension = extension.chain(ext) else: - self.extension = extension + self.extension = ext + self.class_ = class_ self.is_primary = is_primary self.order_by = order_by @@ -425,7 +437,10 @@ class Mapper(object): e.g. result = usermapper.select_by(user_name = 'fred') """ - return mapperutil.SelectResults(self, self._by_clause(*args, **params)) + ret = self.extension.select_by(self, *args, **params) + if ret is not EXT_PASS: + return ret + return self.select_whereclause(self._by_clause(*args, **params)) def selectfirst_by(self, *args, **params): """works like select_by(), but only returns the first result by itself, or None if no @@ -434,7 +449,7 @@ class Mapper(object): def selectone_by(self, *args, **params): """works like selectfirst_by(), but throws an error if not exactly one result was returned.""" - ret = list(self.select_by(*args, **params)[0:2]) + ret = mapper.select_whereclause(self._by_clause(*args, **params), limit=2) if len(ret) == 1: return ret[0] raise InvalidRequestError('Multiple rows returned for selectone_by') @@ -510,7 +525,7 @@ class Mapper(object): return ret[0] raise InvalidRequestError('Multiple rows returned for selectone') - def select(self, arg = None, **kwargs): + def select(self, arg=None, **kwargs): """selects instances of the object from the database. arg can be any ClauseElement, which will form the criterion with which to @@ -520,10 +535,14 @@ class Mapper(object): will be executed and its resulting rowset used to build new object instances. in this case, the developer must insure that an adequate set of columns exists in the rowset with which to build new object instances.""" - if arg is not None and isinstance(arg, sql.Selectable): + + ret = self.extension.select(self, arg=arg, **kwargs) + if ret is not EXT_PASS: + return ret + elif arg is not None and isinstance(arg, sql.Selectable): return self.select_statement(arg, **kwargs) else: - return mapperutil.SelectResults(self, arg, ops=kwargs) + return self.select_whereclause(whereclause=arg, **kwargs) def select_whereclause(self, whereclause=None, params=None, **kwargs): statement = self._compile(whereclause, **kwargs) @@ -850,7 +869,7 @@ class Mapper(object): imap[identitykey] = instance for prop in self.props.values(): prop.execute(instance, row, identitykey, imap, True) - if self.extension.append_result(self, row, imap, result, instance, isnew, populate_existing=populate_existing): + if self.extension.append_result(self, row, imap, result, instance, isnew, populate_existing=populate_existing) is EXT_PASS: if result is not None: result.append_nohistory(instance) return instance @@ -865,7 +884,7 @@ class Mapper(object): return None # plugin point instance = self.extension.create_instance(self, row, imap, self.class_) - if instance is None: + if instance is EXT_PASS: instance = self.class_(_mapper_nohistory=True) imap[identitykey] = instance isnew = True @@ -877,9 +896,9 @@ class Mapper(object): # call further mapper properties on the row, to pull further # instances from the row and possibly populate this item. - if self.extension.populate_instance(self, instance, row, identitykey, imap, isnew): + if self.extension.populate_instance(self, instance, row, identitykey, imap, isnew) is EXT_PASS: self.populate_instance(instance, row, identitykey, imap, isnew) - if self.extension.append_result(self, row, imap, result, instance, isnew, populate_existing=populate_existing): + if self.extension.append_result(self, row, imap, result, instance, isnew, populate_existing=populate_existing) is EXT_PASS: if result is not None: result.append_nohistory(instance) return instance @@ -966,6 +985,19 @@ class ExtensionOption(MapperOption): class MapperExtension(object): def __init__(self): self.next = None + def chain(self, ext): + self.next = ext + return self + def select_by(self, mapper, *args, **kwargs): + if self.next is None: + return EXT_PASS + else: + return self.next.select_by(mapper, *args, **kwargs) + def select(self, mapper, *args, **kwargs): + if self.next is None: + return EXT_PASS + else: + return self.next.select(mapper, *args, **kwargs) def create_instance(self, mapper, row, imap, class_): """called when a new object instance is about to be created from a row. the method can choose to create the instance itself, or it can return @@ -981,7 +1013,7 @@ class MapperExtension(object): class_ - the class we are mapping. """ if self.next is None: - return None + return EXT_PASS else: return self.next.create_instance(mapper, row, imap, class_) def append_result(self, mapper, row, imap, result, instance, isnew, populate_existing=False): @@ -1011,7 +1043,7 @@ class MapperExtension(object): identity map, i.e. were loaded by a previous select(), get their attributes overwritten """ if self.next is None: - return True + return EXT_PASS else: return self.next.append_result(mapper, row, imap, result, instance, isnew, populate_existing) def populate_instance(self, mapper, instance, row, identitykey, imap, isnew): @@ -1024,10 +1056,10 @@ class MapperExtension(object): def populate_instance(self, mapper, instance, row, identitykey, imap, isnew): othermapper.populate_instance(instance, row, identitykey, imap, isnew, frommapper=mapper) - return False + return True """ if self.next is None: - return True + return EXT_PASS else: return self.next.populate_instance(row, imap, result, instance, isnew) def before_insert(self, mapper, instance): diff --git a/lib/sqlalchemy/mapping/properties.py b/lib/sqlalchemy/mapping/properties.py index 206905aa42..fa14dc4deb 100644 --- a/lib/sqlalchemy/mapping/properties.py +++ b/lib/sqlalchemy/mapping/properties.py @@ -382,9 +382,6 @@ class PropertyLoader(MapperProperty): else: uowcommit.register_dependency(self.mapper, self.parent) uowcommit.register_processor(self.mapper, self, self.parent, False) - # this dependency processor is used to locate "private" child objects - # during a "delete" operation, when the objectstore is being committed - # with only a partial list of objects uowcommit.register_processor(self.mapper, self, self.parent, True) else: raise AssertionError(" no foreign key ?") @@ -616,7 +613,7 @@ class LazyLoader(PropertyLoader): order_by = self.secondary.default_order_by() else: order_by = False - result = list(self.mapper.select(self.lazywhere, order_by=order_by, params=params)) + result = self.mapper.select_whereclause(self.lazywhere, order_by=order_by, params=params) else: result = [] if self.uselist: diff --git a/lib/sqlalchemy/mapping/util.py b/lib/sqlalchemy/mapping/util.py index 8e780eeef4..18e2ac0538 100644 --- a/lib/sqlalchemy/mapping/util.py +++ b/lib/sqlalchemy/mapping/util.py @@ -1,75 +1,5 @@ import sqlalchemy.sql as sql -class SelectResults(object): - def __init__(self, mapper, clause=None, ops={}): - self._mapper = mapper - self._clause = clause - self._ops = {} - self._ops.update(ops) - - def count(self): - return self._mapper.count(self._clause) - - def min(self, col): - return sql.select([sql.func.min(col)], self._clause, **self._ops).scalar() - - def max(self, col): - return sql.select([sql.func.max(col)], self._clause, **self._ops).scalar() - - def sum(self, col): - return sql.select([sql.func.sum(col)], self._clause, **self._ops).scalar() - - def avg(self, col): - return sql.select([sql.func.avg(col)], self._clause, **self._ops).scalar() - - def clone(self): - return SelectResults(self._mapper, self._clause, self._ops.copy()) - - def filter(self, clause): - new = self.clone() - new._clause = sql.and_(self._clause, clause) - return new - - def order_by(self, order_by): - new = self.clone() - new._ops['order_by'] = order_by - return new - - def limit(self, limit): - return self[:limit] - - def offset(self, offset): - return self[offset:] - - def list(self): - return list(self) - - def __getitem__(self, item): - if isinstance(item, slice): - start = item.start - stop = item.stop - if (isinstance(start, int) and start < 0) or \ - (isinstance(stop, int) and stop < 0): - return list(self)[item] - else: - res = self.clone() - if start is not None and stop is not None: - res._ops.update(dict(offset=start, limit=stop-start)) - elif start is None and stop is not None: - res._ops.update(dict(limit=stop)) - elif start is not None and stop is None: - res._ops.update(dict(offset=start)) - if item.step is not None: - return list(res)[None:None:item.step] - else: - return res - else: - return list(self[item:item+1])[0] - - def __iter__(self): - return iter(self._mapper.select_whereclause(self._clause, **self._ops)) - - class TableFinder(sql.ClauseVisitor): """given a Clause, locates all the Tables within it into a list.""" def __init__(self, table, check_columns=False): diff --git a/lib/sqlalchemy/mods/__init__.py b/lib/sqlalchemy/mods/__init__.py new file mode 100644 index 0000000000..71b6756649 --- /dev/null +++ b/lib/sqlalchemy/mods/__init__.py @@ -0,0 +1,3 @@ +def install_mods(*mods): + for mod in mods: + mod.install_plugin() \ No newline at end of file diff --git a/lib/sqlalchemy/mods/selectresults.py b/lib/sqlalchemy/mods/selectresults.py new file mode 100644 index 0000000000..9613d8fc0f --- /dev/null +++ b/lib/sqlalchemy/mods/selectresults.py @@ -0,0 +1,84 @@ +import sqlalchemy.sql as sql + +import sqlalchemy.mapping as mapping + +def install_plugin(): + mapping.extensions.append(SelectResultsExt) + +class SelectResultsExt(mapping.MapperExtension): + def select_by(self, mapper, *args, **params): + return SelectResults(mapper, mapper._by_clause(*args, **params)) + def select(self, mapper, arg=None, **kwargs): + if arg is not None and isinstance(arg, sql.Selectable): + return mapping.EXT_PASS + else: + return SelectResults(mapper, arg, ops=kwargs) + +class SelectResults(object): + def __init__(self, mapper, clause=None, ops={}): + self._mapper = mapper + self._clause = clause + self._ops = {} + self._ops.update(ops) + + def count(self): + return self._mapper.count(self._clause) + + def min(self, col): + return sql.select([sql.func.min(col)], self._clause, **self._ops).scalar() + + def max(self, col): + return sql.select([sql.func.max(col)], self._clause, **self._ops).scalar() + + def sum(self, col): + return sql.select([sql.func.sum(col)], self._clause, **self._ops).scalar() + + def avg(self, col): + return sql.select([sql.func.avg(col)], self._clause, **self._ops).scalar() + + def clone(self): + return SelectResults(self._mapper, self._clause, self._ops.copy()) + + def filter(self, clause): + new = self.clone() + new._clause = sql.and_(self._clause, clause) + return new + + def order_by(self, order_by): + new = self.clone() + new._ops['order_by'] = order_by + return new + + def limit(self, limit): + return self[:limit] + + def offset(self, offset): + return self[offset:] + + def list(self): + return list(self) + + def __getitem__(self, item): + if isinstance(item, slice): + start = item.start + stop = item.stop + if (isinstance(start, int) and start < 0) or \ + (isinstance(stop, int) and stop < 0): + return list(self)[item] + else: + res = self.clone() + if start is not None and stop is not None: + res._ops.update(dict(offset=start, limit=stop-start)) + elif start is None and stop is not None: + res._ops.update(dict(limit=stop)) + elif start is not None and stop is None: + res._ops.update(dict(offset=start)) + if item.step is not None: + return list(res)[None:None:item.step] + else: + return res + else: + return list(self[item:item+1])[0] + + def __iter__(self): + return iter(self._mapper.select_whereclause(self._clause, **self._ops)) diff --git a/test/mapper.py b/test/mapper.py index 6bfc3f3b89..4a8edd0974 100644 --- a/test/mapper.py +++ b/test/mapper.py @@ -133,7 +133,7 @@ class MapperTest(MapperSuperTest): # object isnt refreshed yet, using dict to bypass trigger self.assert_(u.__dict__['user_name'] != 'jack') # do a select - m.select().list() + m.select() # test that it refreshed self.assert_(u.__dict__['user_name'] == 'jack') @@ -255,7 +255,7 @@ class MapperTest(MapperSuperTest): m = mapper(User, users, properties = dict( addresses = relation(mapper(Address, addresses), lazy = True) )) - l = m.options(eagerload('addresses')).select().list() + l = m.options(eagerload('addresses')).select() def go(): self.assert_result(l, User, *user_address_result) @@ -266,7 +266,7 @@ class MapperTest(MapperSuperTest): m = mapper(User, users, properties = dict( addresses = relation(mapper(Address, addresses), lazy = False) )) - l = m.options(lazyload('addresses')).select().list() + l = m.options(lazyload('addresses')).select() def go(): self.assert_result(l, User, *user_address_result) self.assert_sql_count(db, go, 3) @@ -282,12 +282,12 @@ class MapperTest(MapperSuperTest): }) m2 = m.options(eagerload('orders.items.keywords')) - u = m.select().list() + u = m.select() def go(): print u[0].orders[1].items[0].keywords[1] self.assert_sql_count(db, go, 3) objectstore.clear() - u = m2.select().list() + u = m2.select() self.assert_sql_count(db, go, 2) class PropertyTest(MapperSuperTest): @@ -368,7 +368,7 @@ class DeferredTest(MapperSuperTest): self.assert_(o.description is None) def go(): - l = m.select().list() + l = m.select() o2 = l[2] print o2.description @@ -397,7 +397,7 @@ class DeferredTest(MapperSuperTest): }) def go(): - l = m.select().list() + l = m.select() o2 = l[2] print o2.opened, o2.description, o2.userident self.assert_sql(db, go, [ @@ -410,7 +410,7 @@ class DeferredTest(MapperSuperTest): m = mapper(Order, orders) m2 = m.options(defer('user_id')) def go(): - l = m2.select().list() + l = m2.select() print l[2].user_id self.assert_sql(db, go, [ ("SELECT orders.order_id AS orders_order_id, orders.description AS orders_description, orders.isopen AS orders_isopen FROM orders ORDER BY orders.%s" % orders.default_order_by()[0].key, {}), @@ -419,7 +419,7 @@ class DeferredTest(MapperSuperTest): objectstore.clear() m3 = m2.options(undefer('user_id')) def go(): - l = m3.select().list() + l = m3.select() print l[3].user_id self.assert_sql(db, go, [ ("SELECT orders.order_id AS orders_order_id, orders.user_id AS orders_user_id, orders.description AS orders_description, orders.isopen AS orders_isopen FROM orders ORDER BY orders.%s" % orders.default_order_by()[0].key, {}), diff --git a/test/proxy_engine.py b/test/proxy_engine.py index cd01272b5c..170e526d96 100644 --- a/test/proxy_engine.py +++ b/test/proxy_engine.py @@ -96,7 +96,7 @@ class ThreadProxyTest(PersistTest): try: trans = objectstore.begin() - all = User.select()[:].list() + all = User.select()[:] assert all == [] u = User() -- 2.47.2