]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
mapper's querying facilities migrated to new query.Query() object, which can receive...
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 6 Apr 2006 21:12:00 +0000 (21:12 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 6 Apr 2006 21:12:00 +0000 (21:12 +0000)
session now propigates to the unitofwork UOWTransaction object, as well as mapper's save_obj/delete_obj via the UOWTransaction it receives. UOWTransaction explicitly calls the Session for the engine corresponding to each Mapper in the flush operation, although the Session does not yet affect the choice of engines used, and mapper save/delete is still using the Table's implicit SQLEngine.
changed internal unitofwork commit() method to be called flush().
removed all references to 'engine' from mapper module, including adding insert/update specific SQLEngine methods such as last_inserted_ids, last_inserted_params, etc. to the returned ResultProxy so that Mapper need not know which SQLEngine was used for the execute.
changes to unit tests, SelectResults to support the new Query object.

doc/build/content/docstrings.myt
lib/sqlalchemy/__init__.py
lib/sqlalchemy/engine.py
lib/sqlalchemy/mapping/mapper.py
lib/sqlalchemy/mapping/objectstore.py
lib/sqlalchemy/mapping/properties.py
lib/sqlalchemy/mapping/query.py [new file with mode: 0644]
lib/sqlalchemy/mapping/unitofwork.py
lib/sqlalchemy/mods/selectresults.py
test/mapper.py
test/objectstore.py

index f84c58806a27076e2333adc1891256f6eaf85a7e..1e157b6aae545b6c255f48b6f4b855ee39705e73 100644 (file)
@@ -17,6 +17,7 @@
 <& pydoc.myt:obj_doc, obj=sql, classes=[sql.ClauseParameters, sql.Compiled, sql.ClauseElement, sql.TableClause, sql.ColumnClause] &>
 <& pydoc.myt:obj_doc, obj=pool, classes=[pool.DBProxy, pool.Pool, pool.QueuePool, pool.SingletonThreadPool] &>
 <& pydoc.myt:obj_doc, obj=mapping, classes=[mapping.Mapper, mapping.MapperExtension] &>
+<& pydoc.myt:obj_doc, obj=mapping.query, classes=[mapping.query.Query] &>
 <& pydoc.myt:obj_doc, obj=mapping.objectstore, classes=[mapping.objectstore.Session, mapping.objectstore.Session.SessionTrans] &>
 <& pydoc.myt:obj_doc, obj=exceptions &>
 <& pydoc.myt:obj_doc, obj=proxy &>
index bbb57955badb086d40b5cfab70163d021f1c920b..94a0fcb6dffd8af90ae478c6fa26b0649f29eb01 100644 (file)
@@ -9,8 +9,9 @@ from types import *
 from sql import *
 from schema import *
 from exceptions import *
-import mapping as mapperlib
-from mapping import *
+import sqlalchemy.sql
+import sqlalchemy.mapping as mapping
+from sqlalchemy.mapping import *
 import sqlalchemy.schema
 import sqlalchemy.ext.proxy
 sqlalchemy.schema.default_engine = sqlalchemy.ext.proxy.ProxyEngine()
index 97c71076234462febbd00e8e6c3b42f9c4afd086..9572a4310db06c50afd580d0c34a5d977962a4d6 100644 (file)
@@ -817,6 +817,17 @@ class ResultProxy:
                 raise StopIteration
             else:
                 yield row
+     
+    def last_inserted_ids(self):
+        return self.engine.last_inserted_ids()
+    def last_updated_params(self):
+        return self.engine.last_updated_params()
+    def last_inserted_params(self):
+        return self.engine.last_inserted_params()
+    def lastrow_has_defaults(self):
+        return self.engine.lastrow_has_defaults()
+    def supports_sane_rowcount(self):
+        return self.engine.supports_sane_rowcount()
         
     def fetchall(self):
         """fetches all rows, just like DBAPI cursor.fetchall()."""
index e82aaeb12e7bedd21db4350f2eca79253e666819..7d6f9890cc69ff62b4b4ed330a765880f938d537 100644 (file)
@@ -7,11 +7,11 @@
 
 import sqlalchemy.sql as sql
 import sqlalchemy.schema as schema
-import sqlalchemy.engine as engine
 import sqlalchemy.util as util
 import util as mapperutil
 import sync
 from sqlalchemy.exceptions import *
+import query
 import objectstore
 import sys
 import weakref
@@ -205,10 +205,6 @@ class Mapper(object):
             proplist = self.columntoproperty.setdefault(column.original, [])
             proplist.append(prop)
 
-        self._get_clause = sql.and_()
-        for primary_key in self.pks_by_table[self.table]:
-            self._get_clause.clauses.append(primary_key == sql.bindparam("pk_"+primary_key.key))
-
         if not mapper_registry.has_key(self.class_key) or self.is_primary or (inherits is not None and inherits._is_primary_mapper()):
             objectstore.global_attributes.reset_class_managed(self.class_)
             self._init_class()
@@ -229,9 +225,75 @@ class Mapper(object):
         #print "mapper %s, columntoproperty:" % (self.class_.__name__)
         #for key, value in self.columntoproperty.iteritems():
         #    print key.table.name, key.key, [(v.key, v) for v in value]
-            
-    engines = property(lambda s: [t.engine for t in s.tables])
 
+    def _get_query(self):
+        try:
+            if self._query.mapper is not self:
+                self._query = query.Query(self)
+            return self._query
+        except AttributeError:
+            self._query = query.Query(self)
+            return self._query
+    query = property(_get_query, doc=\
+        """returns an instance of sqlalchemy.mapping.query.Query, which implements all the query-constructing
+        methods such as get(), select(), select_by(), etc.  The default Query object uses the global thread-local
+        Session from the objectstore package.  To get a Query object for a specific Session, call the 
+        using(session) method.""")
+    
+    def get(self, *ident, **kwargs):
+        """calls get() on this mapper's default Query object."""
+        return self.query.get(*ident, **kwargs)
+        
+    def _get(self, key, ident=None, reload=False):
+        return self.query._get(key, ident=ident, reload=reload)
+        
+    def get_by(self, *args, **params):
+        """calls get_by() on this mapper's default Query object."""
+        return self.query.get_by(*args, **params)
+
+    def select_by(self, *args, **params):
+        """calls select_by() on this mapper's default Query object."""
+        return self.query.select_by(*args, **params)
+
+    def selectfirst_by(self, *args, **params):
+        """calls selectfirst_by() on this mapper's default Query object."""
+        return self.query.selectfirst_by(*args, **params)
+
+    def selectone_by(self, *args, **params):
+        """calls selectone_by() on this mapper's default Query object."""
+        return self.query.selectone_by(*args, **params)
+
+    def count_by(self, *args, **params):
+        """calls count_by() on this mapper's default Query object."""
+        return self.query.count_by(*args, **params)
+
+    def selectfirst(self, *args, **params):
+        """calls selectfirst() on this mapper's default Query object."""
+        return self.query.selectfirst(*args, **params)
+
+    def selectone(self, *args, **params):
+        """calls selectone() on this mapper's default Query object."""
+        return self.query.selectone(*args, **params)
+
+    def select(self, arg=None, **kwargs):
+        """calls select() on this mapper's default Query object."""
+        return self.query.select(arg=arg, **kwargs)
+
+    def select_whereclause(self, whereclause=None, params=None, **kwargs):
+        """calls select_whereclause() on this mapper's default Query object."""
+        return self.query.select_whereclause(whereclause=whereclause, params=params, **kwargs)
+
+    def count(self, whereclause=None, params=None, **kwargs):
+        """calls count() on this mapper's default Query object."""
+        return self.query.count(whereclause=whereclause, params=params, **kwargs)
+
+    def select_statement(self, statement, **params):
+        """calls select_statement() on this mapper's default Query object."""
+        return self.query.select_statement(statement, **params)
+
+    def select_text(self, text, **params):
+        return self.query.select_text(text, **params)
+            
     def add_property(self, key, prop):
         """adds an additional property to this mapper.  this is the same as if it were 
         specified within the 'properties' argument to the constructor.  if the named
@@ -293,12 +355,18 @@ class Mapper(object):
         mapper_registry[self.class_key] = self
         if self.entity_name is None:
             self.class_.c = self.c
+    
+    def has_eager(self):
+        """returns True if one of the properties attached to this Mapper is eager loading"""
+        return getattr(self, '_has_eager', False)
         
     def set_property(self, key, prop):
         self.props[key] = prop
         prop.init(key, self)
     
     def instances(self, cursor, *mappers, **kwargs):
+        """given a cursor (ResultProxy) from an SQLEngine, returns a list of object instances
+        corresponding to the rows in the cursor."""
         limit = kwargs.get('limit', None)
         offset = kwargs.get('offset', None)
         session = kwargs.get('session', None)
@@ -330,37 +398,6 @@ class Mapper(object):
         if mappers:
             result = [result] + otherresults
         return result
-            
-    def get(self, *ident, **kwargs):
-        """returns an instance of the object based on the given identifier, or None
-        if not found.  The *ident argument is a 
-        list of primary key columns in the order of the table def's primary key columns."""
-        key = objectstore.get_id_key(ident, self.class_, self.entity_name)
-        #print "key: " + repr(key) + " ident: " + repr(ident)
-        return self._get(key, ident, **kwargs)
-        
-    def _get(self, key, ident=None, reload=False, session=None):
-        if not reload and not self.always_refresh:
-            try:
-                if session is None:
-                    session = objectstore.get_session()
-                return session._get(key)
-            except KeyError:
-                pass
-            
-        if ident is None:
-            ident = key[1]
-        i = 0
-        params = {}
-        for primary_key in self.pks_by_table[self.table]:
-            params["pk_"+primary_key.key] = ident[i]
-            i += 1
-        try:
-            statement = self._compile(self._get_clause)
-            return self._select_statement(statement, params=params, populate_existing=reload, session=session)[0]
-        except IndexError:
-            return None
-
         
     def identity_key(self, *primary_key):
         """returns the instance key for the given identity value.  this is a global tracking object used by the objectstore, and is usually available off a mapped object as instance._instance_key."""
@@ -377,7 +414,7 @@ class Mapper(object):
     def compile(self, whereclause = None, **options):
         """works like select, except returns the SQL statement object without 
         compiling or executing it"""
-        return self._compile(whereclause, **options)
+        return self.query._compile(whereclause, **options)
 
     def copy(self, **kwargs):
         mapper = Mapper.__new__(Mapper)
@@ -387,22 +424,11 @@ class Mapper(object):
         return mapper
     
     def using(self, session):
-        """returns a proxying object to this mapper, which will execute methods on the mapper
-        within the context of the given session.  The session is placed as the "current" session
-        via the push_session/pop_session methods in the objectstore module."""
+        """returns a new Query object with the given Session."""
         if objectstore.get_session() is session:
-            return self
-        mapper = self
-        class Proxy(object):
-            def __getattr__(self, key):
-                def callit(*args, **kwargs):
-                    objectstore.push_session(session)
-                    try:
-                        return getattr(mapper, key)(*args, **kwargs)
-                    finally:
-                        objectstore.pop_session()
-                return callit
-        return Proxy()
+            return self.query
+        else:
+            return query.Query(self, session=session)
 
     def options(self, *options, **kwargs):
         """uses this mapper as a prototype for a new mapper with different behavior.
@@ -418,169 +444,12 @@ class Mapper(object):
             self._options[optkey] = mapper
             return mapper
 
-    def get_by(self, *args, **params):
-        """returns a single object instance based on the given key/value criterion. 
-        this is either the first value in the result list, or None if the list is 
-        empty.
-        
-        the keys are mapped to property or column names mapped by this mapper's Table, and the values
-        are coerced into a WHERE clause separated by AND operators.  If the local property/column
-        names dont contain the key, a search will be performed against this mapper's immediate
-        list of relations as well, forming the appropriate join conditions if a matching property
-        is located.
-        
-        e.g.   u = usermapper.get_by(user_name = 'fred')
-        """
-        x = self.select_whereclause(self._by_clause(*args, **params), limit=1)
-        if x:
-            return x[0]
-        else:
-            return None
-            
-    def select_by(self, *args, **params):
-        """returns an array of object instances based on the given clauses and key/value criterion. 
-        
-        *args is a list of zero or more ClauseElements which will be connected by AND operators.
-        **params is a set of zero or more key/value parameters which are converted into ClauseElements.
-        the keys are mapped to property or column names mapped by this mapper's Table, and the values
-        are coerced into a WHERE clause separated by AND operators.  If the local property/column
-        names dont contain the key, a search will be performed against this mapper's immediate
-        list of relations as well, forming the appropriate join conditions if a matching property
-        is located.
-        
-        e.g.   result = usermapper.select_by(user_name = 'fred')
-        """
-        ret = self.extension.select_by(self, *args, **params)
-        if ret is not EXT_PASS:
-            return ret
-        return self.select_whereclause(self._by_clause(*args, **params))
-    
-    def selectfirst_by(self, *args, **params):
-        """works like select_by(), but only returns the first result by itself, or None if no 
-        objects returned.  Synonymous with get_by()"""
-        return self.get_by(*args, **params)
-
-    def selectone_by(self, *args, **params):
-        """works like selectfirst_by(), but throws an error if not exactly one result was returned."""
-        ret = mapper.select_whereclause(self._by_clause(*args, **params), limit=2)
-        if len(ret) == 1:
-            return ret[0]
-        raise InvalidRequestError('Multiple rows returned for selectone_by')
-
-    def count_by(self, *args, **params):
-        """returns the count of instances based on the given clauses and key/value criterion.
-        The criterion is constructed in the same way as the select_by() method."""
-        return self.count(self._by_clause(*args, **params))
-        
-    def _by_clause(self, *args, **params):
-        clause = None
-        for arg in args:
-            if clause is None:
-                clause = arg
-            else:
-                clause &= arg
-        for key, value in params.iteritems():
-            if value is False:
-                continue
-            c = self._get_criterion(key, value)
-            if c is None:
-                raise InvalidRequestError("Cant find criterion for property '"+ key + "'")
-            if clause is None:
-                clause = c
-            else:                
-                clause &= c
-        return clause
-
-    def _get_criterion(self, key, value):
-        """used by select_by to match a key/value pair against
-        local properties, column names, or a matching property in this mapper's
-        list of relations."""
-        if self.props.has_key(key):
-            return self.props[key].columns[0] == value
-        elif self.table.c.has_key(key):
-            return self.table.c[key] == value
-        else:
-            for prop in self.props.values():
-                c = prop.get_criterion(key, value)
-                if c is not None:
-                    return c
-            else:
-                return None
-
     def __getattr__(self, key):
-        if (key.startswith('select_by_')):
-            key = key[10:]
-            def foo(arg):
-                return self.select_by(**{key:arg})
-            return foo
-        elif (key.startswith('get_by_')):
-            key = key[7:]
-            def foo(arg):
-                return self.get_by(**{key:arg})
-            return foo
+        if (key.startswith('select_by_') or key.startswith('get_by_')):
+            return getattr(self.query, key)
         else:
             raise AttributeError(key)
-        
-    def selectfirst(self, *args, **params):
-        """works like select(), but only returns the first result by itself, or None if no 
-        objects returned."""
-        params['limit'] = 1
-        ret = self.select_whereclause(*args, **params)
-        if ret:
-            return ret[0]
-        else:
-            return None
-            
-    def selectone(self, *args, **params):
-        """works like selectfirst(), but throws an error if not exactly one result was returned."""
-        ret = list(self.select(*args, **params)[0:2])
-        if len(ret) == 1:
-            return ret[0]
-        raise InvalidRequestError('Multiple rows returned for selectone')
             
-    def select(self, arg=None, **kwargs):
-        """selects instances of the object from the database.  
-        
-        arg can be any ClauseElement, which will form the criterion with which to
-        load the objects.
-        
-        For more advanced usage, arg can also be a Select statement object, which
-        will be executed and its resulting rowset used to build new object instances.  
-        in this case, the developer must insure that an adequate set of columns exists in the 
-        rowset with which to build new object instances."""
-
-        ret = self.extension.select(self, arg=arg, **kwargs)
-        if ret is not EXT_PASS:
-            return ret
-        elif arg is not None and isinstance(arg, sql.Selectable):
-            return self.select_statement(arg, **kwargs)
-        else:
-            return self.select_whereclause(whereclause=arg, **kwargs)
-
-    def select_whereclause(self, whereclause=None, params=None, session=None, **kwargs):
-        statement = self._compile(whereclause, **kwargs)
-        return self._select_statement(statement, params=params, session=session)
-
-    def count(self, whereclause=None, params=None, **kwargs):
-        s = self.table.count(whereclause)
-        if params is not None:
-            return s.scalar(**params)
-        else:
-            return s.scalar()
-
-    def select_statement(self, statement, **params):
-        return self._select_statement(statement, params=params)
-
-    def select_text(self, text, **params):
-        t = sql.text(text, engine=self.primarytable.engine)
-        return self.instances(t.execute(**params))
-
-    def _select_statement(self, statement, params=None, **kwargs):
-        statement.use_labels = True
-        if params is None:
-            params = {}
-        return self.instances(statement.execute(**params), **kwargs)
-
     def _getpropbycolumn(self, column, raiseerror=True):
         try:
             prop = self.columntoproperty[column.original]
@@ -604,7 +473,6 @@ class Mapper(object):
 
     def _setattrbycolumn(self, obj, column, value):
         self.columntoproperty[column.original][0].setattr(obj, value)
-
         
     def save_obj(self, objects, uow, postupdate=False):
         """called by a UnitOfWork object to save objects, which involves either an INSERT or
@@ -714,17 +582,17 @@ class Mapper(object):
                 for rec in update:
                     (obj, params) = rec
                     c = statement.execute(params)
-                    self._postfetch(table, obj, table.engine.last_updated_params())
+                    self._postfetch(table, obj, c, c.last_updated_params())
                     self.extension.after_update(self, obj)
                     rows += c.cursor.rowcount
-                if table.engine.supports_sane_rowcount() and rows != len(update):
+                if c.supports_sane_rowcount() and rows != len(update):
                     raise CommitError("ConcurrencyError - updated rowcount %d does not match number of objects updated %d" % (rows, len(update)))
             if len(insert):
                 statement = table.insert()
                 for rec in insert:
                     (obj, params) = rec
-                    statement.execute(**params)
-                    primary_key = table.engine.last_inserted_ids()
+                    c = statement.execute(**params)
+                    primary_key = c.last_inserted_ids()
                     if primary_key is not None:
                         i = 0
                         for col in self.pks_by_table[table]:
@@ -732,16 +600,16 @@ class Mapper(object):
                             if self._getattrbycolumn(obj, col) is None:
                                 self._setattrbycolumn(obj, col, primary_key[i])
                             i+=1
-                    self._postfetch(table, obj, table.engine.last_inserted_params())
+                    self._postfetch(table, obj, c, c.last_inserted_params())
                     if self._synchronizer is not None:
                         self._synchronizer.execute(obj, obj)
                     self.extension.after_insert(self, obj)
 
-    def _postfetch(self, table, obj, params):
-        """after an INSERT or UPDATE, asks the engine if PassiveDefaults fired off on the database side
+    def _postfetch(self, table, obj, resultproxy, params):
+        """after an INSERT or UPDATE, asks the returned result if PassiveDefaults fired off on the database side
         which need to be post-fetched, *or* if pre-exec defaults like ColumnDefaults were fired off
         and should be populated into the instance. this is only for non-primary key columns."""
-        if table.engine.lastrow_has_defaults():
+        if resultproxy.lastrow_has_defaults():
             clause = sql.and_()
             for p in self.pks_by_table[table]:
                 clause.clauses.append(p == self._getattrbycolumn(obj, p))
@@ -785,7 +653,7 @@ class Mapper(object):
                     clause.clauses.append(self.version_id_col == sql.bindparam(self.version_id_col.key))
                 statement = table.delete(clause)
                 c = statement.execute(*delete)
-                if table.engine.supports_sane_rowcount() and c.rowcount != len(delete):
+                if c.supports_sane_rowcount() and c.rowcount != len(delete):
                     raise CommitError("ConcurrencyError - updated rowcount %d does not match number of objects updated %d" % (c.cursor.rowcount, len(delete)))
 
     def _has_pks(self, table):
@@ -811,52 +679,6 @@ class Mapper(object):
         for prop in self.props.values():
             prop.register_deleted(obj, uow)
     
-    def _should_nest(self, **kwargs):
-        """returns True if the given statement options indicate that we should "nest" the
-        generated query as a subquery inside of a larger eager-loading query.  this is used
-        with keywords like distinct, limit and offset and the mapper defines eager loads."""
-        return (
-            getattr(self, '_has_eager', False)
-            and (kwargs.has_key('limit') or kwargs.has_key('offset') or kwargs.get('distinct', False))
-        )
-        
-    def _compile(self, whereclause = None, **kwargs):
-        order_by = kwargs.pop('order_by', False)
-        if order_by is False:
-            order_by = self.order_by
-        if order_by is False:
-            if self.table.default_order_by() is not None:
-                order_by = self.table.default_order_by()
-
-        if self._should_nest(**kwargs):
-            s2 = sql.select(self.table.primary_key, whereclause, use_labels=True, from_obj=[self.table], **kwargs)
-#            raise "ok first thing", str(s2)
-            if not kwargs.get('distinct', False) and order_by:
-                s2.order_by(*util.to_list(order_by))
-            s3 = s2.alias('rowcount')
-            crit = []
-            for i in range(0, len(self.table.primary_key)):
-                crit.append(s3.primary_key[i] == self.table.primary_key[i])
-            statement = sql.select([], sql.and_(*crit), from_obj=[self.table], use_labels=True)
- #           raise "OK statement", str(statement)
-            if order_by:
-                statement.order_by(*util.to_list(order_by))
-        else:
-            statement = sql.select([], whereclause, from_obj=[self.table], use_labels=True, **kwargs)
-            if order_by:
-                statement.order_by(*util.to_list(order_by))
-            # for a DISTINCT query, you need the columns explicitly specified in order
-            # to use it in "order_by".  insure they are in the column criterion (particularly oid).
-            # TODO: this should be done at the SQL level not the mapper level
-            if kwargs.get('distinct', False) and order_by:
-                statement.append_column(*util.to_list(order_by))
-        # plugin point
-        
-            
-        # give all the attached properties a chance to modify the query
-        for key, value in self.props.iteritems():
-            value.setup(key, statement, **kwargs) 
-        return statement
         
     def _identity_key(self, row):
         return objectstore.get_row_key(row, self.class_, self.pks_by_table[self.table], self.entity_name)
@@ -1003,16 +825,18 @@ class MapperExtension(object):
     def chain(self, ext):
         self.next = ext
         return self    
-    def select_by(self, mapper, *args, **kwargs):
+    def select_by(self, query, *args, **kwargs):
+        """overrides the select_by method of the Query object"""
         if self.next is None:
             return EXT_PASS
         else:
-            return self.next.select_by(mapper, *args, **kwargs)
-    def select(self, mapper, *args, **kwargs):
+            return self.next.select_by(query, *args, **kwargs)
+    def select(self, query, *args, **kwargs):
+        """overrides the select method of the Query object"""
         if self.next is None:
             return EXT_PASS
         else:
-            return self.next.select(mapper, *args, **kwargs)
+            return self.next.select(query, *args, **kwargs)
     def create_instance(self, mapper, row, imap, class_):
         """called when a new object instance is about to be created from a row.  
         the method can choose to create the instance itself, or it can return 
index ee0470cde06ca4a5b76ef845e7d26c52da2d6a82..1491d39ac0e0e750d050bee41308236315169632 100644 (file)
@@ -118,6 +118,9 @@ class Session(object):
         self.uow = unitofwork.UnitOfWork(identity_map = self.uow.identity_map)
         return Session.SessionTrans(self, self.uow, True)
     
+    def engines(self, mapper):
+        return [t.engine for t in mapper.tables]
+        
     def _trans_commit(self, trans):
         if trans.uow is self.uow and trans.isactive:
             try:
@@ -133,7 +136,7 @@ class Session(object):
     def _commit_uow(self, *obj):
         self.was_pushed()
         try:
-            self.uow.commit(*obj)
+            self.uow.flush(self, *obj)
         finally:
             self.was_popped()
                         
@@ -147,7 +150,7 @@ class Session(object):
         # change begin/commit status
         if len(objects):
             self._commit_uow(*objects)
-            self.uow.commit(*objects)
+            self.uow.flush(self, *objects)
             return
         if self.parent_uow is None:
             self._commit_uow()
@@ -283,13 +286,13 @@ def import_instance(instance):
     return get_session().import_instance(instance)
 
 def mapper(*args, **params):
-    return sqlalchemy.mapperlib.mapper(*args, **params)
+    return sqlalchemy.mapping.mapper(*args, **params)
 
 def object_mapper(obj):
-    return sqlalchemy.mapperlib.object_mapper(obj)
+    return sqlalchemy.mapping.object_mapper(obj)
 
 def class_mapper(class_):
-    return sqlalchemy.mapperlib.class_mapper(class_)
+    return sqlalchemy.mapping.class_mapper(class_)
 
 global_attributes = unitofwork.global_attributes
 
index 592e4dc0ac5afb5ba392c8a90a1301c5a6d07e85..af504100cf73c294498c8fda56aa98e9946ca07e 100644 (file)
@@ -582,7 +582,7 @@ class LazyLoader(PropertyLoader):
         (self.lazywhere, self.lazybinds) = create_lazy_clause(self.parent.noninherited_table, self.primaryjoin, self.secondaryjoin, self.foreignkey)
         # determine if our "lazywhere" clause is the same as the mapper's
         # get() clause.  then we can just use mapper.get()
-        self.use_get = not self.uselist and self.mapper._get_clause.compare(self.lazywhere)
+        self.use_get = not self.uselist and self.mapper.query._get_clause.compare(self.lazywhere)
         
     def _set_class_attribute(self, class_, key):
         # establish a class-level lazy loader on our class
@@ -609,14 +609,14 @@ class LazyLoader(PropertyLoader):
                     ident = []
                     for primary_key in self.mapper.pks_by_table[self.mapper.table]:
                         ident.append(params[primary_key._label])
-                    return self.mapper.get(session=session, *ident)
+                    return self.mapper.using(session).get(*ident)
                 elif self.order_by is not False:
                     order_by = self.order_by
                 elif self.secondary is not None and self.secondary.default_order_by() is not None:
                     order_by = self.secondary.default_order_by()
                 else:
                     order_by = False
-                result = self.mapper.select_whereclause(self.lazywhere, order_by=order_by, params=params, session=session)
+                result = self.mapper.using(session).select_whereclause(self.lazywhere, order_by=order_by, params=params)
             else:
                 result = []
             if self.uselist:
diff --git a/lib/sqlalchemy/mapping/query.py b/lib/sqlalchemy/mapping/query.py
new file mode 100644 (file)
index 0000000..09c2b9b
--- /dev/null
@@ -0,0 +1,267 @@
+
+import objectstore
+import sqlalchemy.sql as sql
+import sqlalchemy.util as util
+import mapper
+
+class Query(object):
+    """encapsulates the object-fetching operations provided by Mappers."""
+    def __init__(self, mapper, **kwargs):
+        self.mapper = mapper
+        self.always_refresh = kwargs.pop('always_refresh', self.mapper.always_refresh)
+        self.order_by = kwargs.pop('order_by', self.mapper.order_by)
+        self._session = kwargs.pop('session', None)
+        if not hasattr(mapper, '_get_clause'):
+            _get_clause = sql.and_()
+            for primary_key in self.mapper.pks_by_table[self.table]:
+                _get_clause.clauses.append(primary_key == sql.bindparam("pk_"+primary_key.key))
+            self.mapper._get_clause = _get_clause
+        self._get_clause = self.mapper._get_clause
+    def _get_session(self):
+        if self._session is None:
+            return objectstore.get_session()
+        else:
+            return self._session
+    table = property(lambda s:s.mapper.table)
+    props = property(lambda s:s.mapper.props)
+    session = property(_get_session)
+    
+    def get(self, *ident, **kwargs):
+        """returns an instance of the object based on the given identifier, or None
+        if not found.  The *ident argument is a 
+        list of primary key columns in the order of the table def's primary key columns."""
+        key = self.mapper.identity_key(*ident)
+        #print "key: " + repr(key) + " ident: " + repr(ident)
+        return self._get(key, ident, **kwargs)
+
+    def get_by(self, *args, **params):
+        """returns a single object instance based on the given key/value criterion. 
+        this is either the first value in the result list, or None if the list is 
+        empty.
+
+        the keys are mapped to property or column names mapped by this mapper's Table, and the values
+        are coerced into a WHERE clause separated by AND operators.  If the local property/column
+        names dont contain the key, a search will be performed against this mapper's immediate
+        list of relations as well, forming the appropriate join conditions if a matching property
+        is located.
+
+        e.g.   u = usermapper.get_by(user_name = 'fred')
+        """
+        x = self.select_whereclause(self._by_clause(*args, **params), limit=1)
+        if x:
+            return x[0]
+        else:
+            return None
+
+    def select_by(self, *args, **params):
+        """returns an array of object instances based on the given clauses and key/value criterion. 
+
+        *args is a list of zero or more ClauseElements which will be connected by AND operators.
+        **params is a set of zero or more key/value parameters which are converted into ClauseElements.
+        the keys are mapped to property or column names mapped by this mapper's Table, and the values
+        are coerced into a WHERE clause separated by AND operators.  If the local property/column
+        names dont contain the key, a search will be performed against this mapper's immediate
+        list of relations as well, forming the appropriate join conditions if a matching property
+        is located.
+
+        e.g.   result = usermapper.select_by(user_name = 'fred')
+        """
+        ret = self.mapper.extension.select_by(self, *args, **params)
+        if ret is not mapper.EXT_PASS:
+            return ret
+        return self.select_whereclause(self._by_clause(*args, **params))
+
+    def selectfirst_by(self, *args, **params):
+        """works like select_by(), but only returns the first result by itself, or None if no 
+        objects returned.  Synonymous with get_by()"""
+        return self.get_by(*args, **params)
+
+    def selectone_by(self, *args, **params):
+        """works like selectfirst_by(), but throws an error if not exactly one result was returned."""
+        ret = self.select_whereclause(self._by_clause(*args, **params), limit=2)
+        if len(ret) == 1:
+            return ret[0]
+        raise InvalidRequestError('Multiple rows returned for selectone_by')
+
+    def count_by(self, *args, **params):
+        """returns the count of instances based on the given clauses and key/value criterion.
+        The criterion is constructed in the same way as the select_by() method."""
+        return self.count(self._by_clause(*args, **params))
+
+    def selectfirst(self, *args, **params):
+        """works like select(), but only returns the first result by itself, or None if no 
+        objects returned."""
+        params['limit'] = 1
+        ret = self.select_whereclause(*args, **params)
+        if ret:
+            return ret[0]
+        else:
+            return None
+
+    def selectone(self, *args, **params):
+        """works like selectfirst(), but throws an error if not exactly one result was returned."""
+        ret = list(self.select(*args, **params)[0:2])
+        if len(ret) == 1:
+            return ret[0]
+        raise InvalidRequestError('Multiple rows returned for selectone')
+
+    def select(self, arg=None, **kwargs):
+        """selects instances of the object from the database.  
+
+        arg can be any ClauseElement, which will form the criterion with which to
+        load the objects.
+
+        For more advanced usage, arg can also be a Select statement object, which
+        will be executed and its resulting rowset used to build new object instances.  
+        in this case, the developer must insure that an adequate set of columns exists in the 
+        rowset with which to build new object instances."""
+
+        ret = self.mapper.extension.select(self, arg=arg, **kwargs)
+        if ret is not mapper.EXT_PASS:
+            return ret
+        elif arg is not None and isinstance(arg, sql.Selectable):
+            return self.select_statement(arg, **kwargs)
+        else:
+            return self.select_whereclause(whereclause=arg, **kwargs)
+
+    def select_whereclause(self, whereclause=None, params=None, **kwargs):
+        statement = self._compile(whereclause, **kwargs)
+        return self._select_statement(statement, params=params)
+
+    def count(self, whereclause=None, params=None, **kwargs):
+        s = self.table.count(whereclause)
+        if params is not None:
+            return s.scalar(**params)
+        else:
+            return s.scalar()
+
+    def select_statement(self, statement, **params):
+        return self._select_statement(statement, params=params)
+
+    def select_text(self, text, **params):
+        t = sql.text(text, engine=self.mapper.primarytable.engine)
+        return self.instances(t.execute(**params))
+
+    def __getattr__(self, key):
+        if (key.startswith('select_by_')):
+            key = key[10:]
+            def foo(arg):
+                return self.select_by(**{key:arg})
+            return foo
+        elif (key.startswith('get_by_')):
+            key = key[7:]
+            def foo(arg):
+                return self.get_by(**{key:arg})
+            return foo
+        else:
+            raise AttributeError(key)
+
+    def instances(self, *args, **kwargs):
+        return self.mapper.instances(session=self.session, *args, **kwargs)
+        
+    def _by_clause(self, *args, **params):
+        clause = None
+        for arg in args:
+            if clause is None:
+                clause = arg
+            else:
+                clause &= arg
+        for key, value in params.iteritems():
+            if value is False:
+                continue
+            c = self._get_criterion(key, value)
+            if c is None:
+                raise InvalidRequestError("Cant find criterion for property '"+ key + "'")
+            if clause is None:
+                clause = c
+            else:                
+                clause &= c
+        return clause
+
+    def _get(self, key, ident=None, reload=False):
+        if not reload and not self.always_refresh:
+            try:
+                return self.session._get(key)
+            except KeyError:
+                pass
+
+        if ident is None:
+            ident = key[1]
+        i = 0
+        params = {}
+        for primary_key in self.mapper.pks_by_table[self.table]:
+            params["pk_"+primary_key.key] = ident[i]
+            i += 1
+        try:
+            statement = self._compile(self._get_clause)
+            return self._select_statement(statement, params=params, populate_existing=reload)[0]
+        except IndexError:
+            return None
+
+    def _select_statement(self, statement, params=None, **kwargs):
+        statement.use_labels = True
+        if params is None:
+            params = {}
+        return self.instances(statement.execute(**params), **kwargs)
+
+    def _should_nest(self, **kwargs):
+        """returns True if the given statement options indicate that we should "nest" the
+        generated query as a subquery inside of a larger eager-loading query.  this is used
+        with keywords like distinct, limit and offset and the mapper defines eager loads."""
+        return (
+            self.mapper.has_eager()
+            and (kwargs.has_key('limit') or kwargs.has_key('offset') or kwargs.get('distinct', False))
+        )
+
+    def _compile(self, whereclause = None, **kwargs):
+        order_by = kwargs.pop('order_by', False)
+        if order_by is False:
+            order_by = self.order_by
+        if order_by is False:
+            if self.table.default_order_by() is not None:
+                order_by = self.table.default_order_by()
+
+        if self._should_nest(**kwargs):
+            s2 = sql.select(self.table.primary_key, whereclause, use_labels=True, from_obj=[self.table], **kwargs)
+#            raise "ok first thing", str(s2)
+            if not kwargs.get('distinct', False) and order_by:
+                s2.order_by(*util.to_list(order_by))
+            s3 = s2.alias('rowcount')
+            crit = []
+            for i in range(0, len(self.table.primary_key)):
+                crit.append(s3.primary_key[i] == self.table.primary_key[i])
+            statement = sql.select([], sql.and_(*crit), from_obj=[self.table], use_labels=True)
+ #           raise "OK statement", str(statement)
+            if order_by:
+                statement.order_by(*util.to_list(order_by))
+        else:
+            statement = sql.select([], whereclause, from_obj=[self.table], use_labels=True, **kwargs)
+            if order_by:
+                statement.order_by(*util.to_list(order_by))
+            # for a DISTINCT query, you need the columns explicitly specified in order
+            # to use it in "order_by".  insure they are in the column criterion (particularly oid).
+            # TODO: this should be done at the SQL level not the mapper level
+            if kwargs.get('distinct', False) and order_by:
+                statement.append_column(*util.to_list(order_by))
+        # plugin point
+
+        # give all the attached properties a chance to modify the query
+        for key, value in self.mapper.props.iteritems():
+            value.setup(key, statement, **kwargs) 
+        return statement
+
+    def _get_criterion(self, key, value):
+        """used by select_by to match a key/value pair against
+        local properties, column names, or a matching property in this mapper's
+        list of relations."""
+        if self.props.has_key(key):
+            return self.props[key].columns[0] == value
+        elif self.table.c.has_key(key):
+            return self.table.c[key] == value
+        else:
+            for prop in self.props.values():
+                c = prop.get_criterion(key, value)
+                if c is not None:
+                    return c
+            else:
+                return None
index b08836a208d03898d7bc3e3560335a3722f856e8..3ef1d96aec85be3c7578dd949ec3f3f33aeb1437 100644 (file)
@@ -5,13 +5,13 @@
 # the MIT License: http://www.opensource.org/licenses/mit-license.php
 
 """the internals for the Unit Of Work system.  includes hooks into the attributes package
-enabling the routing of change events to Unit Of Work objects, as well as the commit mechanism
+enabling the routing of change events to Unit Of Work objects, as well as the flush() mechanism
 which creates a dependency structure that executes change operations.  
 
 a Unit of Work is essentially a system of maintaining a graph of in-memory objects and their
 modified state.  Objects are maintained as unique against their primary key identity using
 an "identity map" pattern.  The Unit of Work then maintains lists of objects that are new, 
-dirty, or deleted and provides the capability to commit all those changes at once.
+dirty, or deleted and provides the capability to flush all those changes at once.
 """
 
 from sqlalchemy import attributes
@@ -23,7 +23,7 @@ import weakref
 import topological
 from sets import *
 
-# a global indicating if all commit() operations should have their plan
+# a global indicating if all flush() operations should have their plan
 # printed to standard output.  also can be affected by creating an engine
 # with the "echo_uow=True" keyword argument.
 LOG = False
@@ -73,7 +73,7 @@ class UOWAttributeManager(attributes.AttributeManager):
         return UOWListElement(obj, key, list_, **kwargs)
         
 class UnitOfWork(object):
-    """main UOW object which stores lists of dirty/new/deleted objects, as well as 'modified_lists' for list attributes.  provides top-level "commit" functionality as well as the transaction boundaries with the SQLEngine(s) involved in a write operation."""
+    """main UOW object which stores lists of dirty/new/deleted objects, as well as 'modified_lists' for list attributes.  provides top-level "flush" functionality as well as the transaction boundaries with the SQLEngine(s) involved in a write operation."""
     def __init__(self, identity_map=None):
         if identity_map is not None:
             self.identity_map = identity_map
@@ -141,7 +141,7 @@ class UnitOfWork(object):
         self.attributes.remove(obj)
 
     def _validate_obj(self, obj):
-        """validates that dirty/delete/commit operations can occur upon the given object, by checking
+        """validates that dirty/delete/flush operations can occur upon the given object, by checking
         if it has an instance key and that the instance key is present in the identity map."""
         if hasattr(obj, '_instance_key') and not self.identity_map.has_key(obj._instance_key):
             raise InvalidRequestError("Detected a mapped object not present in the current thread's Identity Map: '%s'.  Use objectstore.import_instance() to place deserialized instances or instances from other threads" % repr(obj._instance_key))
@@ -203,8 +203,8 @@ class UnitOfWork(object):
         except KeyError:
             pass
             
-    def commit(self, *objects):
-        commit_context = UOWTransaction(self)
+    def flush(self, session, *objects):
+        flush_context = UOWTransaction(self, session)
 
         if len(objects):
             objset = util.HashSet(iter=objects)
@@ -216,29 +216,29 @@ class UnitOfWork(object):
                 continue
             if self.deleted.contains(obj):
                 continue
-            commit_context.register_object(obj)
+            flush_context.register_object(obj)
         for item in self.modified_lists:
             obj = item.obj
             if objset is not None and not objset.contains(obj):
                 continue
             if self.deleted.contains(obj):
                 continue
-            commit_context.register_object(obj, listonly = True)
-            commit_context.register_saved_history(item)
+            flush_context.register_object(obj, listonly = True)
+            flush_context.register_saved_history(item)
 
 #            for o in item.added_items() + item.deleted_items():
 #                if self.deleted.contains(o):
 #                    continue
-#                commit_context.register_object(o, listonly=True)
+#                flush_context.register_object(o, listonly=True)
                      
         for obj in self.deleted:
             if objset is not None and not objset.contains(obj):
                 continue
-            commit_context.register_object(obj, isdelete=True)
+            flush_context.register_object(obj, isdelete=True)
                 
         engines = util.HashSet()
-        for mapper in commit_context.mappers:
-            for e in mapper.engines:
+        for mapper in flush_context.mappers:
+            for e in session.engines(mapper):
                 engines.append(e)
         
         echo_commit = False        
@@ -246,7 +246,7 @@ class UnitOfWork(object):
             echo_commit = echo_commit or e.echo_uow
             e.begin()
         try:
-            commit_context.execute(echo=echo_commit)
+            flush_context.execute(echo=echo_commit)
         except:
             for e in engines:
                 e.rollback()
@@ -254,7 +254,7 @@ class UnitOfWork(object):
         for e in engines:
             e.commit()
             
-        commit_context.post_exec()
+        flush_context.post_exec()
         
 
     def rollback_object(self, obj):
@@ -271,10 +271,10 @@ class UnitOfWork(object):
             
 class UOWTransaction(object):
     """handles the details of organizing and executing transaction tasks 
-    during a UnitOfWork object's commit() operation."""
-    def __init__(self, uow):
+    during a UnitOfWork object's flush() operation."""
+    def __init__(self, uow, session):
         self.uow = uow
-
+        self.session = session
         #  unique list of all the mappers we come across
         self.mappers = util.HashSet()
         self.dependencies = {}
@@ -379,8 +379,8 @@ class UOWTransaction(object):
                 print "\nExecute complete (no post-exec changes)\n"
             
     def post_exec(self):
-        """after an execute/commit is completed, all of the objects and lists that have
-        been committed are updated in the parent UnitOfWork object to mark them as clean."""
+        """after an execute/flush is completed, all of the objects and lists that have
+        been flushed are updated in the parent UnitOfWork object to mark them as clean."""
         
         for task in self.tasks.values():
             for elem in task.objects.values():
@@ -396,7 +396,7 @@ class UOWTransaction(object):
             except KeyError:
                 pass
 
-    # this assertion only applies to a full commit(), not a
+    # this assertion only applies to a full flush(), not a
     # partial one
         #if len(self.uow.new) > 0 or len(self.uow.dirty) >0 or len(self.uow.modified_lists) > 0:
         #    raise "assertion failed"
index b4f16c41ced0f4d09fe44e802c24382dd7d62353..5528c7bf660b4fb817ce1d79acb814ac8f6d134f 100644 (file)
@@ -6,25 +6,25 @@ def install_plugin():
     mapping.global_extensions.append(SelectResultsExt)
     
 class SelectResultsExt(mapping.MapperExtension):
-    def select_by(self, mapper, *args, **params):
-        return SelectResults(mapper, mapper._by_clause(*args, **params))
-    def select(self, mapper, arg=None, **kwargs):
+    def select_by(self, query, *args, **params):
+        return SelectResults(query, query._by_clause(*args, **params))
+    def select(self, query, arg=None, **kwargs):
         if arg is not None and isinstance(arg, sql.Selectable):
             return mapping.EXT_PASS
         else:
-            return SelectResults(mapper, arg, ops=kwargs)
+            return SelectResults(query, arg, ops=kwargs)
 
 MapperExtension = SelectResultsExt
         
 class SelectResults(object):
-    def __init__(self, mapper, clause=None, ops={}):
-        self._mapper = mapper
+    def __init__(self, query, clause=None, ops={}):
+        self._query = query
         self._clause = clause
         self._ops = {}
         self._ops.update(ops)
 
     def count(self):
-        return self._mapper.count(self._clause)
+        return self._query.count(self._clause)
     
     def min(self, col):
         return sql.select([sql.func.min(col)], self._clause, **self._ops).scalar()
@@ -39,7 +39,7 @@ class SelectResults(object):
         return sql.select([sql.func.avg(col)], self._clause, **self._ops).scalar()
 
     def clone(self):
-        return SelectResults(self._mapper, self._clause, self._ops.copy())
+        return SelectResults(self._query, self._clause, self._ops.copy())
         
     def filter(self, clause):
         new = self.clone()
@@ -83,4 +83,4 @@ class SelectResults(object):
             return list(self[item:item+1])[0]
     
     def __iter__(self):
-        return iter(self._mapper.select_whereclause(self._clause, **self._ops))
+        return iter(self._query.select_whereclause(self._clause, **self._ops))
index 41c29cfdfcb0d9f4661f26faede7ae7106682f55..b3c2a2ab095a6c5244b15555d9d883c3587a9925 100644 (file)
@@ -127,7 +127,7 @@ class MapperTest(MapperSuperTest):
     def testsessionpropigation(self):
         sess = objectstore.Session()
         m = mapper(User, users, properties={'addresses':relation(mapper(Address, addresses), lazy=True)})
-        u = m.get(7, session=sess)
+        u = m.using(sess).get(7)
         assert objectstore.get_session(u) is sess
         assert objectstore.get_session(u.addresses[0]) is sess
         
index 9166c5dea8ac064d6c9aefaf1fd9545068bce29b..4e61d08f0ce6c7c08e2b9516f488c11922b04fe2 100644 (file)
@@ -172,6 +172,7 @@ class SessionTest(AssertMixin):
 
 class VersioningTest(AssertMixin):
     def setUpAll(self):
+        objectstore.clear()
         global version_table
         version_table = Table('version_test', db,
         Column('id', Integer, primary_key=True),
@@ -226,6 +227,7 @@ class VersioningTest(AssertMixin):
         
 class UnicodeTest(AssertMixin):
     def setUpAll(self):
+        objectstore.clear()
         global uni_table
         uni_table = Table('uni_test', db,
             Column('id',  Integer, primary_key=True),