]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
refactoring step 2. all deprecated functions now express their functionality
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 4 Jun 2007 23:50:22 +0000 (23:50 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 4 Jun 2007 23:50:22 +0000 (23:50 +0000)
in terms of generative behavior.  also the thing will run like crap right now until
the next refactor stage...stay tuned

lib/sqlalchemy/orm/query.py
test/orm/mapper.py
test/orm/query.py

index 31ba414d4a0c554e9ee6a904b0b488304c8f3f0c..c894f3767fd3c0c06fb560fec36976af67896e8c 100644 (file)
@@ -39,11 +39,15 @@ class Query(object):
         self._distinct = kwargs.pop('distinct', False)
         self._offset = kwargs.pop('offset', None)
         self._limit = kwargs.pop('limit', None)
+        self._statement = None
+        self._params = {}
         self._criterion = None
         self._col = None
         self._func = None
         self._joinpoint = self.mapper
         self._from_obj = [self.table]
+        self._populate_existing = False
+        self._version_check = False
 
         for opt in util.flatten_iterator(self.with_options):
             opt.process_query(self)
@@ -68,6 +72,10 @@ class Query(object):
         q._from_obj = list(self._from_obj)
         q._joinpoint = self._joinpoint
         q._criterion = self._criterion
+        q._statement = self._statement
+        q._params = self._params.copy()
+        q._populate_existing = self._populate_existing
+        q._version_check = self._version_check
         q._col = self._col
         q._func = self._func
         return q
@@ -143,6 +151,11 @@ class Query(object):
         else:
             primary_key = self.primary_key_columns
             s = sql.select([sql.func.count(list(primary_key)[0])], whereclause, from_obj=from_obj, **kwargs)
+        if params is None:
+            params = {}
+        else:
+            params = params.copy()
+        params.update(self._params)
         return self.session.scalar(self.mapper, s, params=params)
 
     def _with_lazy_criterion(cls, instance, prop, reverse=False):
@@ -284,6 +297,13 @@ class Query(object):
         q.lockmode = mode
         return q
     
+    def params(self, **kwargs):
+        """add values for bind parameters which may have been specified in filter()."""
+        
+        q = self._clone()
+        q._params.update(kwargs)
+        return q
+        
     def filter(self, criterion):
         """apply the given filtering criterion to the query and return the newly resulting ``Query``
         
@@ -565,13 +585,24 @@ class Query(object):
 
         return list(self)
     
+    def from_statement(self, statement):
+        if isinstance(statement, basestring):
+            statement = sql.text(statement)
+        q = self._clone()
+        q._statement = statement
+        return q
+        
     def first(self):
         """Return the first result of this ``Query``.
 
         This results in an execution of the underlying query.
         """
         if self._col is None or self._func is None: 
-            return self[0]
+            ret = list(self[0:1])
+            if len(ret) > 0:
+                return ret[0]
+            else:
+                return None
         else:
             return self._col_aggregate(self._col, self._func)
 
@@ -594,7 +625,13 @@ class Query(object):
             raise exceptions.InvalidRequestError('Multiple rows returned for one()')
     
     def __iter__(self):
-        return iter(self.select_whereclause())
+        statement = self.compile()
+        statement.use_labels = True
+        result = self.session.execute(self.mapper, statement, params=self._params)
+        try:
+            return iter(self.instances(result))
+        finally:
+            result.close()
 
 
     def instances(self, cursor, *mappers_or_columns, **kwargs):
@@ -624,6 +661,9 @@ class Query(object):
 
         session = self.session
 
+        kwargs.setdefault('populate_existing', self._populate_existing)
+        kwargs.setdefault('version_check', self._version_check)
+        
         context = SelectionContext(self.select_mapper, session, self.extension, with_options=self.with_options, **kwargs)
 
         process = []
@@ -685,8 +725,12 @@ class Query(object):
         for i, primary_key in enumerate(self.primary_key_columns):
             params[primary_key._label] = ident[i]
         try:
-            statement = self.compile(self._get_clause, lockmode=lockmode)
-            return self._select_statement(statement, params=params, populate_existing=reload, version_check=(lockmode is not None))[0]
+            q = self
+            if lockmode is not None:
+                q = q.with_lockmode(lockmode)
+            q = q.filter(self._get_clause)
+            q = q.params(**params)._select_context_options(populate_existing=reload, version_check=(lockmode is not None))
+            return q.first()
         except IndexError:
             return None
 
@@ -708,13 +752,14 @@ class Query(object):
 
         return (kwargs.get('limit') is not None or kwargs.get('offset') is not None or kwargs.get('distinct', False))
 
-    def compile(self, whereclause = None, **kwargs):
-        """Given a WHERE criterion, produce a ClauseElement-based
-        statement suitable for usage in the execute() method.
-        """
-
-        if self._criterion:
-            whereclause = sql.and_(self._criterion, whereclause)
+    def compile(self):
+        """compiles and returns a SQL statement based on the criterion and conditions within this Query."""
+        
+        if self._statement:
+            self._statement.use_labels = True
+            return self._statement
+        
+        whereclause = self._criterion
 
         if whereclause is not None and self.is_polymorphic:
             # adapt the given WHERECLAUSE to adjust instances of this query's mapped 
@@ -732,9 +777,7 @@ class Query(object):
         
         # get/create query context.  get the ultimate compile arguments
         # from there
-        context = kwargs.pop('query_context', None)
-        if context is None:
-            context = QueryContext(self, kwargs)
+        context = QueryContext(self)
         order_by = context.order_by
         group_by = context.group_by
         from_obj = context.from_obj
@@ -790,10 +833,12 @@ class Query(object):
             statement = sql.select([], whereclause, from_obj=from_obj, use_labels=True, for_update=for_update, **context.select_args())
             if order_by:
                 statement.order_by(*util.to_list(order_by))
+                
             # for a DISTINCT query, you need the columns explicitly specified in order
             # to use it in "order_by".  ensure they are in the column criterion (particularly oid).
             # TODO: this should be done at the SQL level not the mapper level
-            if kwargs.get('distinct', False) and order_by:
+            # TODO: need test coverage for this 
+            if context.distinct and order_by:
                 [statement.append_column(c) for c in util.to_list(order_by)]
 
         context.statement = statement
@@ -829,59 +874,33 @@ class Query(object):
 
         return self.count(self.join_by(*args, **params))
 
-    def selectfirst(self, arg=None, **kwargs):
-        """DEPRECATED.  use query.filter(whereclause).first()"""
-
-        if isinstance(arg, sql.FromClause) and arg.supports_execution():
-            ret = self.select_statement(arg, **kwargs)
-        else:
-            kwargs['limit'] = 1
-            ret = self.select_whereclause(whereclause=arg, **kwargs)
-        if ret:
-            return ret[0]
-        else:
-            return None
-
-    def selectone(self, arg=None, **kwargs):
-        """DEPRECATED.  use query.filter(whereclause).one()"""
-        
-        if isinstance(arg, sql.FromClause) and arg.supports_execution():
-            ret = self.select_statement(arg, **kwargs)
-        else:
-            kwargs['limit'] = 2
-            ret = self.select_whereclause(whereclause=arg, **kwargs)
-        if len(ret) == 1:
-            return ret[0]
-        elif len(ret) == 0:
-            raise exceptions.InvalidRequestError('No rows returned for selectone_by')
-        else:
-            raise exceptions.InvalidRequestError('Multiple rows returned for selectone')
-
-    def select(self, arg=None, **kwargs):
-        """DEPRECATED.  use query.filter(whereclause).all(), or query.from_statement(statement).all()"""
-
-        ret = self.extension.select(self, arg=arg, **kwargs)
-        if ret is not mapper.EXT_PASS:
-            return ret
-        if isinstance(arg, sql.FromClause) and arg.supports_execution():
-            return self.select_statement(arg, **kwargs)
-        else:
-            return self.select_whereclause(whereclause=arg, **kwargs)
 
     def select_whereclause(self, whereclause=None, params=None, **kwargs):
         """DEPRECATED.  use query.filter(whereclause).all()"""
 
-        statement = self.compile(whereclause, **kwargs)
-        return self._select_statement(statement, params=params)
-
-    def execute(self, clauseelement, params=None, *args, **kwargs):
-        """DEPRECATED.  use query.select_from()"""
+        q = self.filter(whereclause)._legacy_select_kwargs(**kwargs)
+        if params is not None:
+            q = q.params(**params)
+        return list(q)
+        
+    def _legacy_select_kwargs(self, **kwargs):
+        q = self
+        if "order_by" in kwargs and kwargs['order_by']:
+            q = q.order_by(kwargs['order_by'])
+        if "group_by" in kwargs:
+            q = q.group_by(kwargs['group_by'])
+        if "from_obj" in kwargs:
+            q = q.select_from(kwargs['from_obj'])
+        if "lockmode" in kwargs:
+            q = q.with_lockmode(kwargs['lockmode'])
+        if "distinct" in kwargs:
+            q = q.distinct()
+        if "limit" in kwargs:
+            q = q.limit(kwargs['limit'])
+        if "offset" in kwargs:
+            q = q.offset(kwargs['offset'])
+        return q
 
-        result = self.session.execute(self.mapper, clauseelement, params=params)
-        try:
-            return self.instances(result, **kwargs)
-        finally:
-            result.close()
 
     def get_by(self, *args, **params):
         """DEPRECATED.  use query.filter(*args).filter_by(**params).first()"""
@@ -889,42 +908,76 @@ class Query(object):
         ret = self.extension.get_by(self, *args, **params)
         if ret is not mapper.EXT_PASS:
             return ret
-        x = self.select_whereclause(self.join_by(*args, **params), limit=1)
-        if x:
-            return x[0]
-        else:
-            return None
+
+        return self._legacy_filter_by(*args, **params).first()
 
     def select_by(self, *args, **params):
-        """DEPRECATED. use use query.filter(*args).filter_by(**params).list()."""
+        """DEPRECATED. use use query.filter(*args).filter_by(**params).all()."""
 
         ret = self.extension.select_by(self, *args, **params)
         if ret is not mapper.EXT_PASS:
             return ret
-        return self.select_whereclause(self.join_by(*args, **params))
+
+        return self._legacy_filter_by(*args, **params).list()
 
     def join_by(self, *args, **params):
         """DEPRECATED. use join() to construct joins based on attribute names."""
 
         return self._legacy_join_by(args, params, start=self._joinpoint)
 
+    def _build_select(self, arg=None, params=None, **kwargs):
+        if isinstance(arg, sql.FromClause) and arg.supports_execution():
+            return self.from_statement(arg)
+        else:
+            return self.filter(arg)._legacy_select_kwargs(**kwargs)
+
+    def selectfirst(self, arg=None, **kwargs):
+        """DEPRECATED.  use query.filter(whereclause).first()"""
+
+        return self._build_select(arg, **kwargs).first()
+
+    def selectone(self, arg=None, **kwargs):
+        """DEPRECATED.  use query.filter(whereclause).one()"""
+
+        return self._build_select(arg, **kwargs).one()
+
+    def select(self, arg=None, **kwargs):
+        """DEPRECATED.  use query.filter(whereclause).all(), or query.from_statement(statement).all()"""
+
+        ret = self.extension.select(self, arg=arg, **kwargs)
+        if ret is not mapper.EXT_PASS:
+            return ret
+        return self._build_select(arg, **kwargs).all()
+
+    def execute(self, clauseelement, params=None, *args, **kwargs):
+        """DEPRECATED.  use query.from_statement().all()"""
+
+        return self._select_statement(statement, params, **kwargs)
+
     def select_statement(self, statement, **params):
         """DEPRECATED.  Use query.from_statement(statement)"""
-
-        return self._select_statement(statement, params=params)
+        
+        return self._select_statement(statement, params)
 
     def select_text(self, text, **params):
         """DEPRECATED.  Use query.from_statement(statement)"""
 
-        t = sql.text(text)
-        return self.execute(t, params=params)
+        return self._select_statement(statement, params)
 
     def _select_statement(self, statement, params=None, **kwargs):
-        statement.use_labels = True
-        if params is None:
-            params = {}
-        return self.execute(statement, params=params, **kwargs)
-
+        q = self.from_statement(statement)
+        if params is not None:
+            q = q.params(**params)
+        q._select_context_options(**kwargs)
+        return list(q)
+
+    def _select_context_options(self, populate_existing=None, version_check=None):
+        if populate_existing is not None:
+            self._populate_existing = populate_existing
+        if version_check is not None:
+            self._version_check = version_check
+        return self
+        
     def join_to(self, key):
         """DEPRECATED. use join() to create joins based on property names."""
 
@@ -1001,18 +1054,12 @@ class Query(object):
     def selectfirst_by(self, *args, **params):
         """DEPRECATED. Use query.filter(*args).filter_by(**kwargs).first()"""
 
-        return self.get_by(*args, **params)
+        return self._legacy_filter_by(*args, **params).first()
 
     def selectone_by(self, *args, **params):
         """DEPRECATED. Use query.filter(*args).filter_by(**kwargs).one()"""
 
-        ret = self.select_whereclause(self.join_by(*args, **params), limit=2)
-        if len(ret) == 1:
-            return ret[0]
-        elif len(ret) == 0:
-            raise exceptions.InvalidRequestError('No rows returned for selectone_by')
-        else:
-            raise exceptions.InvalidRequestError('Multiple rows returned for selectone_by')
+        return self._legacy_filter_by(*args, **params).one()
 
 
 
@@ -1024,18 +1071,18 @@ class QueryContext(OperationContext):
     in a query construction.
     """
 
-    def __init__(self, query, kwargs):
+    def __init__(self, query):
         self.query = query
-        self.order_by = kwargs.pop('order_by', query._order_by)
-        self.group_by = kwargs.pop('group_by', query._group_by)
-        self.from_obj = kwargs.pop('from_obj', query._from_obj)
-        self.lockmode = kwargs.pop('lockmode', query.lockmode)
-        self.distinct = kwargs.pop('distinct', query._distinct)
-        self.limit = kwargs.pop('limit', query._limit)
-        self.offset = kwargs.pop('offset', query._offset)
+        self.order_by = query._order_by
+        self.group_by = query._group_by
+        self.from_obj = query._from_obj
+        self.lockmode = query.lockmode
+        self.distinct = query._distinct
+        self.limit = query._limit
+        self.offset = query._offset
         self.eager_loaders = util.Set([x for x in query.mapper._eager_loaders])
         self.statement = None
-        super(QueryContext, self).__init__(query.mapper, query.with_options, **kwargs)
+        super(QueryContext, self).__init__(query.mapper, query.with_options)
 
     def select_args(self):
         """Return a dictionary of attributes from this
index 558fa62809a05c4c871ccafcd068f30919bf6be7..e754945bb8f77ebd253c5e98536e30e62f7c3562 100644 (file)
@@ -1229,7 +1229,7 @@ class EagerTest(MapperSuperTest):
         m = mapper(User, users, properties = dict(
             addresses = relation(mapper(Address, addresses), lazy = False)
         ))
-        s = session.query(m).compile(and_(addresses.c.email_address == bindparam('emailad'), addresses.c.user_id==users.c.user_id))
+        s = session.query(m).filter(and_(addresses.c.email_address == bindparam('emailad'), addresses.c.user_id==users.c.user_id)).compile()
         c = s.compile()
         self.echo("\n" + str(c) + repr(c.get_params()))
         
index 8d3f5e67d438908c857fe50f01fcc29a271711c1..fbee5e88c86646e06092fa0cc18d0c7be2be802b 100644 (file)
@@ -189,6 +189,12 @@ class GetTest(QueryTest):
         class LocalFoo(Base):pass
         mapper(LocalFoo, table)
         assert create_session().query(LocalFoo).get(ustring) == LocalFoo(id=ustring, data=ustring)
+
+class SliceTest(QueryTest):
+    def test_first(self):
+        assert create_session().query(User).first() == User(id=7)
+        
+        assert create_session().query(User).filter(users.c.id==27).first() is None
         
 class FilterTest(QueryTest):
     def test_basic(self):