]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
introducing...the mods package ! the SelectResults thing moves as the first mod
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 31 Mar 2006 07:20:13 +0000 (07:20 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 31 Mar 2006 07:20:13 +0000 (07:20 +0000)
lib/sqlalchemy/__init__.py
lib/sqlalchemy/mapping/mapper.py
lib/sqlalchemy/mapping/properties.py
lib/sqlalchemy/mapping/util.py
lib/sqlalchemy/mods/__init__.py [new file with mode: 0644]
lib/sqlalchemy/mods/selectresults.py [new file with mode: 0644]
test/mapper.py
test/proxy_engine.py

index 9f7b99feab8db0adda5daa507f6e9ec601da4bef..bbb57955badb086d40b5cfab70163d021f1c920b 100644 (file)
@@ -11,11 +11,12 @@ from schema import *
 from exceptions import *
 import mapping as mapperlib
 from mapping import *
-
 import sqlalchemy.schema
 import sqlalchemy.ext.proxy
 sqlalchemy.schema.default_engine = sqlalchemy.ext.proxy.ProxyEngine()
 
+from sqlalchemy.mods import install_mods
+
 def global_connect(*args, **kwargs):
     sqlalchemy.schema.default_engine.connect(*args, **kwargs)
     
\ No newline at end of file
index 7e12459c53763a0ab63391348de37ab84b8a4dc0..9d2286eb0545db6207eb3130a35d5e6ae42fc778 100644 (file)
@@ -19,11 +19,17 @@ import weakref
 # a dictionary mapping classes to their primary mappers
 mapper_registry = weakref.WeakKeyDictionary()
 
+# a list of MapperExtensions that will be installed by default
+extensions = []
+
 # a constant returned by _getattrbycolumn to indicate
 # this mapper is not handling an attribute for a particular
 # column
 NO_ATTRIBUTE = object()
 
+# returned by a MapperExtension method to indicate a "do nothing" response
+EXT_PASS = object()
+
 class Mapper(object):
     """Persists object instances to and from schema.Table objects via the sql package.
     Instances of this class should be constructed through this package's mapper() or
@@ -47,11 +53,17 @@ class Mapper(object):
 
         if primarytable is not None:
             sys.stderr.write("'primarytable' argument to mapper is deprecated\n")
+        
+        ext = MapperExtension()
+        
+        for ext_class in extensions:
+            ext = ext_class().chain(ext)
             
-        if extension is None:
-            self.extension = MapperExtension()
+        if extension is not None:
+            self.extension = extension.chain(ext)
         else:
-            self.extension = extension                
+            self.extension = ext
+
         self.class_ = class_
         self.is_primary = is_primary
         self.order_by = order_by
@@ -425,7 +437,10 @@ class Mapper(object):
         
         e.g.   result = usermapper.select_by(user_name = 'fred')
         """
-        return mapperutil.SelectResults(self, self._by_clause(*args, **params))
+        ret = self.extension.select_by(self, *args, **params)
+        if ret is not EXT_PASS:
+            return ret
+        return self.select_whereclause(self._by_clause(*args, **params))
     
     def selectfirst_by(self, *args, **params):
         """works like select_by(), but only returns the first result by itself, or None if no 
@@ -434,7 +449,7 @@ class Mapper(object):
 
     def selectone_by(self, *args, **params):
         """works like selectfirst_by(), but throws an error if not exactly one result was returned."""
-        ret = list(self.select_by(*args, **params)[0:2])
+        ret = mapper.select_whereclause(self._by_clause(*args, **params), limit=2)
         if len(ret) == 1:
             return ret[0]
         raise InvalidRequestError('Multiple rows returned for selectone_by')
@@ -510,7 +525,7 @@ class Mapper(object):
             return ret[0]
         raise InvalidRequestError('Multiple rows returned for selectone')
             
-    def select(self, arg = None, **kwargs):
+    def select(self, arg=None, **kwargs):
         """selects instances of the object from the database.  
         
         arg can be any ClauseElement, which will form the criterion with which to
@@ -520,10 +535,14 @@ class Mapper(object):
         will be executed and its resulting rowset used to build new object instances.  
         in this case, the developer must insure that an adequate set of columns exists in the 
         rowset with which to build new object instances."""
-        if arg is not None and isinstance(arg, sql.Selectable):
+
+        ret = self.extension.select(self, arg=arg, **kwargs)
+        if ret is not EXT_PASS:
+            return ret
+        elif arg is not None and isinstance(arg, sql.Selectable):
             return self.select_statement(arg, **kwargs)
         else:
-            return mapperutil.SelectResults(self, arg, ops=kwargs)
+            return self.select_whereclause(whereclause=arg, **kwargs)
 
     def select_whereclause(self, whereclause=None, params=None, **kwargs):
         statement = self._compile(whereclause, **kwargs)
@@ -850,7 +869,7 @@ class Mapper(object):
                     imap[identitykey] = instance
                 for prop in self.props.values():
                     prop.execute(instance, row, identitykey, imap, True)
-            if self.extension.append_result(self, row, imap, result, instance, isnew, populate_existing=populate_existing):
+            if self.extension.append_result(self, row, imap, result, instance, isnew, populate_existing=populate_existing) is EXT_PASS:
                 if result is not None:
                     result.append_nohistory(instance)
             return instance
@@ -865,7 +884,7 @@ class Mapper(object):
                     return None
             # plugin point
             instance = self.extension.create_instance(self, row, imap, self.class_)
-            if instance is None:
+            if instance is EXT_PASS:
                 instance = self.class_(_mapper_nohistory=True)
             imap[identitykey] = instance
             isnew = True
@@ -877,9 +896,9 @@ class Mapper(object):
         
         # call further mapper properties on the row, to pull further 
         # instances from the row and possibly populate this item.
-        if self.extension.populate_instance(self, instance, row, identitykey, imap, isnew):
+        if self.extension.populate_instance(self, instance, row, identitykey, imap, isnew) is EXT_PASS:
             self.populate_instance(instance, row, identitykey, imap, isnew)
-        if self.extension.append_result(self, row, imap, result, instance, isnew, populate_existing=populate_existing):
+        if self.extension.append_result(self, row, imap, result, instance, isnew, populate_existing=populate_existing) is EXT_PASS:
             if result is not None:
                 result.append_nohistory(instance)
         return instance
@@ -966,6 +985,19 @@ class ExtensionOption(MapperOption):
 class MapperExtension(object):
     def __init__(self):
         self.next = None
+    def chain(self, ext):
+        self.next = ext
+        return self    
+    def select_by(self, mapper, *args, **kwargs):
+        if self.next is None:
+            return EXT_PASS
+        else:
+            return self.next.select_by(mapper, *args, **kwargs)
+    def select(self, mapper, *args, **kwargs):
+        if self.next is None:
+            return EXT_PASS
+        else:
+            return self.next.select(mapper, *args, **kwargs)
     def create_instance(self, mapper, row, imap, class_):
         """called when a new object instance is about to be created from a row.  
         the method can choose to create the instance itself, or it can return 
@@ -981,7 +1013,7 @@ class MapperExtension(object):
         class_ - the class we are mapping.
         """
         if self.next is None:
-            return None
+            return EXT_PASS
         else:
             return self.next.create_instance(mapper, row, imap, class_)
     def append_result(self, mapper, row, imap, result, instance, isnew, populate_existing=False):
@@ -1011,7 +1043,7 @@ class MapperExtension(object):
         identity map, i.e. were loaded by a previous select(), get their attributes overwritten
         """
         if self.next is None:
-            return True
+            return EXT_PASS
         else:
             return self.next.append_result(mapper, row, imap, result, instance, isnew, populate_existing)
     def populate_instance(self, mapper, instance, row, identitykey, imap, isnew):
@@ -1024,10 +1056,10 @@ class MapperExtension(object):
         
             def populate_instance(self, mapper, instance, row, identitykey, imap, isnew):
                 othermapper.populate_instance(instance, row, identitykey, imap, isnew, frommapper=mapper)
-                return False
+                return True
         """
         if self.next is None:
-            return True
+            return EXT_PASS
         else:
             return self.next.populate_instance(row, imap, result, instance, isnew)
     def before_insert(self, mapper, instance):
index 206905aa42610b0632f89cc082f0e37aed0c5d7c..fa14dc4deb2f84f64462a1da6866c761196ae1de 100644 (file)
@@ -382,9 +382,6 @@ class PropertyLoader(MapperProperty):
             else:
                 uowcommit.register_dependency(self.mapper, self.parent)
                 uowcommit.register_processor(self.mapper, self, self.parent, False)
-                # this dependency processor is used to locate "private" child objects
-                # during a "delete" operation, when the objectstore is being committed
-                # with only a partial list of objects
                 uowcommit.register_processor(self.mapper, self, self.parent, True)
         else:
             raise AssertionError(" no foreign key ?")
@@ -616,7 +613,7 @@ class LazyLoader(PropertyLoader):
                     order_by = self.secondary.default_order_by()
                 else:
                     order_by = False
-                result = list(self.mapper.select(self.lazywhere, order_by=order_by, params=params))
+                result = self.mapper.select_whereclause(self.lazywhere, order_by=order_by, params=params)
             else:
                 result = []
             if self.uselist:
index 8e780eeef4739370218cfc5e4518a2f5e378bf25..18e2ac0538d552afbf676f5c3ebb9c59ba679177 100644 (file)
@@ -1,75 +1,5 @@
 import sqlalchemy.sql as sql
 
-class SelectResults(object):
-    def __init__(self, mapper, clause=None, ops={}):
-        self._mapper = mapper
-        self._clause = clause
-        self._ops = {}
-        self._ops.update(ops)
-
-    def count(self):
-        return self._mapper.count(self._clause)
-    
-    def min(self, col):
-        return sql.select([sql.func.min(col)], self._clause, **self._ops).scalar()
-
-    def max(self, col):
-        return sql.select([sql.func.max(col)], self._clause, **self._ops).scalar()
-
-    def sum(self, col):
-        return sql.select([sql.func.sum(col)], self._clause, **self._ops).scalar()
-
-    def avg(self, col):
-        return sql.select([sql.func.avg(col)], self._clause, **self._ops).scalar()
-
-    def clone(self):
-        return SelectResults(self._mapper, self._clause, self._ops.copy())
-        
-    def filter(self, clause):
-        new = self.clone()
-        new._clause = sql.and_(self._clause, clause)
-        return new
-
-    def order_by(self, order_by):
-        new = self.clone()
-        new._ops['order_by'] = order_by
-        return new
-
-    def limit(self, limit):
-        return self[:limit]
-
-    def offset(self, offset):
-        return self[offset:]
-
-    def list(self):
-        return list(self)
-        
-    def __getitem__(self, item):
-        if isinstance(item, slice):
-            start = item.start
-            stop = item.stop
-            if (isinstance(start, int) and start < 0) or \
-               (isinstance(stop, int) and stop < 0):
-                return list(self)[item]
-            else:
-                res = self.clone()
-                if start is not None and stop is not None:
-                    res._ops.update(dict(offset=start, limit=stop-start))
-                elif start is None and stop is not None:
-                    res._ops.update(dict(limit=stop))
-                elif start is not None and stop is None:
-                    res._ops.update(dict(offset=start))
-                if item.step is not None:
-                    return list(res)[None:None:item.step]
-                else:
-                    return res
-        else:
-            return list(self[item:item+1])[0]
-    
-    def __iter__(self):
-        return iter(self._mapper.select_whereclause(self._clause, **self._ops))
-        
-        
 class TableFinder(sql.ClauseVisitor):
     """given a Clause, locates all the Tables within it into a list."""
     def __init__(self, table, check_columns=False):
diff --git a/lib/sqlalchemy/mods/__init__.py b/lib/sqlalchemy/mods/__init__.py
new file mode 100644 (file)
index 0000000..71b6756
--- /dev/null
@@ -0,0 +1,3 @@
+def install_mods(*mods):
+    for mod in mods:
+        mod.install_plugin()
\ No newline at end of file
diff --git a/lib/sqlalchemy/mods/selectresults.py b/lib/sqlalchemy/mods/selectresults.py
new file mode 100644 (file)
index 0000000..9613d8f
--- /dev/null
@@ -0,0 +1,84 @@
+import sqlalchemy.sql as sql
+
+import sqlalchemy.mapping as mapping
+
+def install_plugin():
+    mapping.extensions.append(SelectResultsExt)
+    
+class SelectResultsExt(mapping.MapperExtension):
+    def select_by(self, mapper, *args, **params):
+        return SelectResults(mapper, mapper._by_clause(*args, **params))
+    def select(self, mapper, arg=None, **kwargs):
+        if arg is not None and isinstance(arg, sql.Selectable):
+            return mapping.EXT_PASS
+        else:
+            return SelectResults(mapper, arg, ops=kwargs)
+        
+class SelectResults(object):
+    def __init__(self, mapper, clause=None, ops={}):
+        self._mapper = mapper
+        self._clause = clause
+        self._ops = {}
+        self._ops.update(ops)
+
+    def count(self):
+        return self._mapper.count(self._clause)
+    
+    def min(self, col):
+        return sql.select([sql.func.min(col)], self._clause, **self._ops).scalar()
+
+    def max(self, col):
+        return sql.select([sql.func.max(col)], self._clause, **self._ops).scalar()
+
+    def sum(self, col):
+        return sql.select([sql.func.sum(col)], self._clause, **self._ops).scalar()
+
+    def avg(self, col):
+        return sql.select([sql.func.avg(col)], self._clause, **self._ops).scalar()
+
+    def clone(self):
+        return SelectResults(self._mapper, self._clause, self._ops.copy())
+        
+    def filter(self, clause):
+        new = self.clone()
+        new._clause = sql.and_(self._clause, clause)
+        return new
+
+    def order_by(self, order_by):
+        new = self.clone()
+        new._ops['order_by'] = order_by
+        return new
+
+    def limit(self, limit):
+        return self[:limit]
+
+    def offset(self, offset):
+        return self[offset:]
+
+    def list(self):
+        return list(self)
+        
+    def __getitem__(self, item):
+        if isinstance(item, slice):
+            start = item.start
+            stop = item.stop
+            if (isinstance(start, int) and start < 0) or \
+               (isinstance(stop, int) and stop < 0):
+                return list(self)[item]
+            else:
+                res = self.clone()
+                if start is not None and stop is not None:
+                    res._ops.update(dict(offset=start, limit=stop-start))
+                elif start is None and stop is not None:
+                    res._ops.update(dict(limit=stop))
+                elif start is not None and stop is None:
+                    res._ops.update(dict(offset=start))
+                if item.step is not None:
+                    return list(res)[None:None:item.step]
+                else:
+                    return res
+        else:
+            return list(self[item:item+1])[0]
+    
+    def __iter__(self):
+        return iter(self._mapper.select_whereclause(self._clause, **self._ops))
index 6bfc3f3b894742b36fec01e41368b7a9f47dc97c..4a8edd09747aec3fb7559940d65269b1c760768e 100644 (file)
@@ -133,7 +133,7 @@ class MapperTest(MapperSuperTest):
         # object isnt refreshed yet, using dict to bypass trigger
         self.assert_(u.__dict__['user_name'] != 'jack')
         # do a select
-        m.select().list()
+        m.select()
         # test that it refreshed
         self.assert_(u.__dict__['user_name'] == 'jack')
         
@@ -255,7 +255,7 @@ class MapperTest(MapperSuperTest):
         m = mapper(User, users, properties = dict(
             addresses = relation(mapper(Address, addresses), lazy = True)
         ))
-        l = m.options(eagerload('addresses')).select().list()
+        l = m.options(eagerload('addresses')).select()
 
         def go():
             self.assert_result(l, User, *user_address_result)
@@ -266,7 +266,7 @@ class MapperTest(MapperSuperTest):
         m = mapper(User, users, properties = dict(
             addresses = relation(mapper(Address, addresses), lazy = False)
         ))
-        l = m.options(lazyload('addresses')).select().list()
+        l = m.options(lazyload('addresses')).select()
         def go():
             self.assert_result(l, User, *user_address_result)
         self.assert_sql_count(db, go, 3)
@@ -282,12 +282,12 @@ class MapperTest(MapperSuperTest):
             })
             
         m2 = m.options(eagerload('orders.items.keywords'))
-        u = m.select().list()
+        u = m.select()
         def go():
             print u[0].orders[1].items[0].keywords[1]
         self.assert_sql_count(db, go, 3)
         objectstore.clear()
-        u = m2.select().list()
+        u = m2.select()
         self.assert_sql_count(db, go, 2)
         
 class PropertyTest(MapperSuperTest):
@@ -368,7 +368,7 @@ class DeferredTest(MapperSuperTest):
         self.assert_(o.description is None)
         
         def go():
-            l = m.select().list()
+            l = m.select()
             o2 = l[2]
             print o2.description
 
@@ -397,7 +397,7 @@ class DeferredTest(MapperSuperTest):
         })
 
         def go():
-            l = m.select().list()
+            l = m.select()
             o2 = l[2]
             print o2.opened, o2.description, o2.userident
         self.assert_sql(db, go, [
@@ -410,7 +410,7 @@ class DeferredTest(MapperSuperTest):
         m = mapper(Order, orders)
         m2 = m.options(defer('user_id'))
         def go():
-            l = m2.select().list()
+            l = m2.select()
             print l[2].user_id
         self.assert_sql(db, go, [
             ("SELECT orders.order_id AS orders_order_id, orders.description AS orders_description, orders.isopen AS orders_isopen FROM orders ORDER BY orders.%s" % orders.default_order_by()[0].key, {}),
@@ -419,7 +419,7 @@ class DeferredTest(MapperSuperTest):
         objectstore.clear()
         m3 = m2.options(undefer('user_id'))
         def go():
-            l = m3.select().list()
+            l = m3.select()
             print l[3].user_id
         self.assert_sql(db, go, [
             ("SELECT orders.order_id AS orders_order_id, orders.user_id AS orders_user_id, orders.description AS orders_description, orders.isopen AS orders_isopen FROM orders ORDER BY orders.%s" % orders.default_order_by()[0].key, {}),
index cd01272b5c35c73595496bb641b0a614859fc642..170e526d96d7af107209f4f69a69fa51410422c9 100644 (file)
@@ -96,7 +96,7 @@ class ThreadProxyTest(PersistTest):
                     try:
                         trans  = objectstore.begin()
 
-                        all = User.select()[:].list()
+                        all = User.select()[:]
                         assert all == []
 
                         u = User()