]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- fixed the test for FalseDiscriminator to use Boolean for picky postgresql
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 25 Jul 2009 20:27:33 +0000 (20:27 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 25 Jul 2009 20:27:33 +0000 (20:27 +0000)
- added Query.enable_assertions(False) as a mediocre solution for [ticket:1424].
updated the recipe at http://www.sqlalchemy.org/trac/wiki/UsageRecipes/PreFilteredQuery to
reflect.
- moved most default Query state to be class level variables to start.  the dicts could
go as well but being overly careful to not place mutables there for the moment.
- a visit by the "dunder-private method names aren't cool" police
- continued undisciplined pep-8ness

CHANGES
lib/sqlalchemy/orm/query.py
test/orm/inheritance/test_basic.py
test/orm/test_query.py

diff --git a/CHANGES b/CHANGES
index 3ad73f5381a64e75d68829e23634260169b22c0c..eebdad0e9aab5b15dcdb33bafb6e8af50ee9b84b 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -23,6 +23,13 @@ CHANGES
     - Using False or 0 as a polymorphic discriminator now
       works on the base class as well as a subclass.
       [ticket:1440]
+
+    - Added enable_assertions(False) to Query which disables
+      the usual assertions for expected state - used
+      by Query subclasses to engineer custom state.
+      [ticket:1424].  See
+      http://www.sqlalchemy.org/trac/wiki/UsageRecipes/PreFilteredQuery
+      for an example.
       
 - sql
     - Fixed a bug in extract() introduced in 0.5.4 whereby
index 78e35421880c1c76e794e1d124c588317e18fbb2..1574545eb7f1ae0591e278cf90a82d4e8a6dd97d 100644 (file)
@@ -54,40 +54,41 @@ def _generative(*assertions):
     return generate
 
 class Query(object):
-    """Encapsulates the object-fetching operations provided by Mappers."""
-
+    """ORM-level SQL construction object."""
+    
+    _enable_eagerloads = True
+    _enable_assertions = True
+    _with_labels = False
+    _criterion = None
+    _yield_per = None
+    _lockmode = None
+    _order_by = False
+    _group_by = False
+    _having = None
+    _distinct = False
+    _offset = None
+    _limit = None
+    _statement = None
+    _joinpoint = None
+    _correlate = frozenset()
+    _populate_existing = False
+    _version_check = False
+    _autoflush = True
+    _current_path = ()
+    _only_load_props = None
+    _refresh_state = None
+    _from_obj = ()
+    _filter_aliases = None
+    _from_obj_alias = None
+    _currenttables = frozenset()
+    
     def __init__(self, entities, session=None):
         self.session = session
 
         self._with_options = []
-        self._lockmode = None
-        self._order_by = False
-        self._group_by = False
-        self._distinct = False
-        self._offset = None
-        self._limit = None
-        self._statement = None
         self._params = {}
-        self._yield_per = None
-        self._criterion = None
-        self._correlate = set()
-        self._joinpoint = None
-        self._with_labels = False
-        self._enable_eagerloads = True
-        self.__joinable_tables = None
-        self._having = None
-        self._populate_existing = False
-        self._version_check = False
-        self._autoflush = True
         self._attributes = {}
-        self._current_path = ()
-        self._only_load_props = None
-        self._refresh_state = None
-        self._from_obj = ()
         self._polymorphic_adapters = {}
-        self._filter_aliases = None
-        self._from_obj_alias = None
-        self.__currenttables = set()
         self._set_entities(entities)
 
     def _set_entities(self, entities, entity_wrapper=None):
@@ -97,9 +98,9 @@ class Query(object):
         for ent in util.to_list(entities):
             entity_wrapper(self, ent)
 
-        self.__setup_aliasizers(self._entities)
+        self._setup_aliasizers(self._entities)
 
-    def __setup_aliasizers(self, entities):
+    def _setup_aliasizers(self, entities):
         if hasattr(self, '_mapper_adapter_map'):
             # usually safe to share a single map, but copying to prevent
             # subtle leaks if end-user is reusing base query with arbitrary
@@ -114,7 +115,8 @@ class Query(object):
                     mapper, selectable, is_aliased_class = _entity_info(entity)
                     if not is_aliased_class and mapper.with_polymorphic:
                         with_polymorphic = mapper._with_polymorphic_mappers
-                        self.__mapper_loads_polymorphically_with(mapper, sql_util.ColumnAdapter(selectable, mapper._equivalent_columns))
+                        self.__mapper_loads_polymorphically_with(mapper, 
+                                sql_util.ColumnAdapter(selectable, mapper._equivalent_columns))
                         adapter = None
                     elif is_aliased_class:
                         adapter = sql_util.ColumnAdapter(selectable, mapper._equivalent_columns)
@@ -131,7 +133,7 @@ class Query(object):
             for m in m2.iterate_to_root():
                 self._polymorphic_adapters[m.mapped_table] = self._polymorphic_adapters[m.local_table] = adapter
 
-    def __set_select_from(self, from_obj):
+    def _set_select_from(self, from_obj):
         if isinstance(from_obj, expression._SelectBaseMixin):
             from_obj = from_obj.alias()
 
@@ -142,7 +144,8 @@ class Query(object):
             self._from_obj_alias = sql_util.ColumnAdapter(from_obj, equivs)
 
     def _get_polymorphic_adapter(self, entity, selectable):
-        self.__mapper_loads_polymorphically_with(entity.mapper, sql_util.ColumnAdapter(selectable, entity.mapper._equivalent_columns))
+        self.__mapper_loads_polymorphically_with(entity.mapper, 
+                                sql_util.ColumnAdapter(selectable, entity.mapper._equivalent_columns))
 
     def _reset_polymorphic_adapter(self, mapper):
         for m2 in mapper._with_polymorphic_mappers:
@@ -151,7 +154,7 @@ class Query(object):
                 self._polymorphic_adapters.pop(m.mapped_table, None)
                 self._polymorphic_adapters.pop(m.local_table, None)
 
-    def __reset_joinpoint(self):
+    def _reset_joinpoint(self):
         self._joinpoint = None
         self._filter_aliases = None
 
@@ -210,9 +213,17 @@ class Query(object):
             return clause
 
         if getattr(self, '_disable_orm_filtering', not orm_only):
-            return visitors.replacement_traverse(clause, {'column_collections':False}, self.__replace_element(adapters))
+            return visitors.replacement_traverse(
+                                clause, 
+                                {'column_collections':False}, 
+                                self.__replace_element(adapters)
+                            )
         else:
-            return visitors.replacement_traverse(clause, {'column_collections':False}, self.__replace_orm_element(adapters))
+            return visitors.replacement_traverse(
+                                clause, 
+                                {'column_collections':False}, 
+                                self.__replace_orm_element(adapters)
+                            )
 
     def _entity_zero(self):
         return self._entities[0]
@@ -243,12 +254,16 @@ class Query(object):
 
     def _only_mapper_zero(self, rationale=None):
         if len(self._entities) > 1:
-            raise sa_exc.InvalidRequestError(rationale or "This operation requires a Query against a single mapper.")
+            raise sa_exc.InvalidRequestError(
+                    rationale or "This operation requires a Query against a single mapper."
+                )
         return self._mapper_zero()
 
     def _only_entity_zero(self, rationale=None):
         if len(self._entities) > 1:
-            raise sa_exc.InvalidRequestError(rationale or "This operation requires a Query against a single mapper.")
+            raise sa_exc.InvalidRequestError(
+                    rationale or "This operation requires a Query against a single mapper."
+                )
         return self._entity_zero()
 
     def _generate_mapper_zero(self):
@@ -264,7 +279,9 @@ class Query(object):
             equivs.update(ent.mapper._equivalent_columns)
         return equivs
 
-    def __no_criterion_condition(self, meth):
+    def _no_criterion_condition(self, meth):
+        if not self._enable_assertions:
+            return
         if self._criterion or self._statement or self._from_obj or \
                 self._limit is not None or self._offset is not None or \
                 self._group_by:
@@ -273,28 +290,36 @@ class Query(object):
         self._from_obj = ()
         self._statement = self._criterion = None
         self._order_by = self._group_by = self._distinct = False
-        self.__joined_tables = {}
 
-    def __no_clauseelement_condition(self, meth):
+    def _no_clauseelement_condition(self, meth):
+        if not self._enable_assertions:
+            return
         if self._order_by:
             raise sa_exc.InvalidRequestError("Query.%s() being called on a Query with existing criterion. " % meth)
-        self.__no_criterion_condition(meth)
+        self._no_criterion_condition(meth)
 
-    def __no_statement_condition(self, meth):
+    def _no_statement_condition(self, meth):
+        if not self._enable_assertions:
+            return
         if self._statement:
             raise sa_exc.InvalidRequestError(
                 ("Query.%s() being called on a Query with an existing full "
                  "statement - can't apply criterion.") % meth)
 
-    def __no_limit_offset(self, meth):
+    def _no_limit_offset(self, meth):
+        if not self._enable_assertions:
+            return
         if self._limit is not None or self._offset is not None:
-            # TODO: do we want from_self() to be implicit here ?  i vote explicit for the time being
-            raise sa_exc.InvalidRequestError("Query.%s() being called on a Query which already has LIMIT or OFFSET applied. "
-            "To modify the row-limited results of a Query, call from_self() first.  Otherwise, call %s() before limit() or offset() are applied." % (meth, meth)
+            raise sa_exc.InvalidRequestError(
+                "Query.%s() being called on a Query which already has LIMIT or OFFSET applied. "
+                "To modify the row-limited results of a Query, call from_self() first.  "
+                "Otherwise, call %s() before limit() or offset() are applied." % (meth, meth)
             )
 
-
-    def __get_options(self, populate_existing=None, version_check=None, only_load_props=None, refresh_state=None):
+    def _get_options(self, populate_existing=None, 
+                            version_check=None, 
+                            only_load_props=None, 
+                            refresh_state=None):
         if populate_existing:
             self._populate_existing = populate_existing
         if version_check:
@@ -315,7 +340,8 @@ class Query(object):
     def statement(self):
         """The full SELECT statement represented by this Query."""
 
-        return self._compile_context(labels=self._with_labels).statement._annotate({'_halt_adapt': True})
+        return self._compile_context(labels=self._with_labels).\
+                        statement._annotate({'_halt_adapt': True})
 
     def subquery(self):
         """return the full SELECT statement represented by this Query, embedded within an Alias.
@@ -358,7 +384,29 @@ class Query(object):
 
         """
         self._with_labels = True
-
+    
+    @_generative()
+    def enable_assertions(self, value):
+        """Control whether assertions are generated.
+        
+        When set to False, the returned Query will 
+        not assert its state before certain operations, 
+        including that LIMIT/OFFSET has not been applied
+        when filter() is called, no criterion exists
+        when get() is called, and no "from_statement()"
+        exists when filter()/order_by()/group_by() etc.
+        is called.  This more permissive mode is used by 
+        custom Query subclasses to specify criterion or 
+        other modifiers outside of the usual usage patterns.
+        
+        Care should be taken to ensure that the usage 
+        pattern is even possible.  A statement applied
+        by from_statement() will override any criterion
+        set by filter() or order_by(), for example.
+        
+        """
+        self._enable_assertions = value
+        
     @property
     def whereclause(self):
         """The WHERE criterion for this Query."""
@@ -375,7 +423,7 @@ class Query(object):
         """
         self._current_path = path
 
-    @_generative(__no_clauseelement_condition)
+    @_generative(_no_clauseelement_condition)
     def with_polymorphic(self, cls_or_mappers, selectable=None, discriminator=None):
         """Load columns for descendant mappers of this Query's mapper.
 
@@ -438,7 +486,9 @@ class Query(object):
         if hasattr(ident, '__composite_values__'):
             ident = ident.__composite_values__()
 
-        key = self._only_mapper_zero("get() can only be used against a single mapped class.").identity_key_from_primary_key(ident)
+        key = self._only_mapper_zero(
+                    "get() can only be used against a single mapped class."
+                ).identity_key_from_primary_key(ident)
         return self._get(key, ident)
 
     @classmethod
@@ -526,7 +576,11 @@ class Query(object):
                 if isinstance(prop, properties.PropertyLoader) and prop.mapper is self._mapper_zero():
                     break
             else:
-                raise sa_exc.InvalidRequestError("Could not locate a property which relates instances of class '%s' to instances of class '%s'" % (self._mapper_zero().class_.__name__, instance.__class__.__name__))
+                raise sa_exc.InvalidRequestError(
+                            "Could not locate a property which relates instances "
+                            "of class '%s' to instances of class '%s'" % 
+                            (self._mapper_zero().class_.__name__, instance.__class__.__name__)
+                        )
         else:
             prop = mapper.get_property(property, resolve_synonyms=True)
         return self.filter(prop.compare(operators.eq, instance, value_is_parent=True))
@@ -540,7 +594,7 @@ class Query(object):
 
         self._entities = list(self._entities)
         m = _MapperEntity(self, entity)
-        self.__setup_aliasizers([m])
+        self._setup_aliasizers([m])
 
     def from_self(self, *entities):
         """return a Query that selects from this Query's SELECT statement.
@@ -562,7 +616,7 @@ class Query(object):
         self._statement = self._criterion = None
         self._order_by = self._group_by = self._distinct = False
         self._limit = self._offset = None
-        self.__set_select_from(fromclause)
+        self._set_select_from(fromclause)
 
     def values(self, *columns):
         """Return an iterator yielding result tuples corresponding to the given list of columns"""
@@ -592,20 +646,20 @@ class Query(object):
         _ColumnEntity(self, column)
         # _ColumnEntity may add many entities if the
         # given arg is a FROM clause
-        self.__setup_aliasizers(self._entities[l:])
+        self._setup_aliasizers(self._entities[l:])
 
     def options(self, *args):
         """Return a new Query object, applying the given list of
         MapperOptions.
 
         """
-        return self.__options(False, *args)
+        return self._options(False, *args)
 
     def _conditional_options(self, *args):
-        return self.__options(True, *args)
+        return self._options(True, *args)
 
     @_generative()
-    def __options(self, conditional, *args):
+    def _options(self, conditional, *args):
         # most MapperOptions write to the '_attributes' dictionary,
         # so copy that as well
         self._attributes = self._attributes.copy()
@@ -641,7 +695,7 @@ class Query(object):
         self._params = self._params.copy()
         self._params.update(kwargs)
 
-    @_generative(__no_statement_condition, __no_limit_offset)
+    @_generative(_no_statement_condition, _no_limit_offset)
     def filter(self, criterion):
         """apply the given filtering criterion to the query and return the newly resulting ``Query``
 
@@ -670,7 +724,7 @@ class Query(object):
         return self.filter(sql.and_(*clauses))
 
 
-    @_generative(__no_statement_condition, __no_limit_offset)
+    @_generative(_no_statement_condition, _no_limit_offset)
     @util.accepts_a_list_as_starargs(list_deprecation='pending')
     def order_by(self, *criterion):
         """apply one or more ORDER BY criterion to the query and return the newly resulting ``Query``"""
@@ -685,7 +739,7 @@ class Query(object):
             else:
                 self._order_by = self._order_by + criterion
 
-    @_generative(__no_statement_condition, __no_limit_offset)
+    @_generative(_no_statement_condition, _no_limit_offset)
     @util.accepts_a_list_as_starargs(list_deprecation='pending')
     def group_by(self, *criterion):
         """apply one or more GROUP BY criterion to the query and return the newly resulting ``Query``"""
@@ -699,7 +753,7 @@ class Query(object):
         else:
             self._group_by = self._group_by + criterion
 
-    @_generative(__no_statement_condition, __no_limit_offset)
+    @_generative(_no_statement_condition, _no_limit_offset)
     def having(self, criterion):
         """apply a HAVING criterion to the query and return the newly resulting ``Query``."""
 
@@ -867,7 +921,7 @@ class Query(object):
         aliased, from_joinpoint = kwargs.pop('aliased', False), kwargs.pop('from_joinpoint', False)
         if kwargs:
             raise TypeError("unknown arguments: %s" % ','.join(kwargs.iterkeys()))
-        return self.__join(props, outerjoin=False, create_aliases=aliased, from_joinpoint=from_joinpoint)
+        return self._join(props, outerjoin=False, create_aliases=aliased, from_joinpoint=from_joinpoint)
 
     @util.accepts_a_list_as_starargs(list_deprecation='pending')
     def outerjoin(self, *props, **kwargs):
@@ -880,19 +934,19 @@ class Query(object):
         aliased, from_joinpoint = kwargs.pop('aliased', False), kwargs.pop('from_joinpoint', False)
         if kwargs:
             raise TypeError("unknown arguments: %s" % ','.join(kwargs.iterkeys()))
-        return self.__join(props, outerjoin=True, create_aliases=aliased, from_joinpoint=from_joinpoint)
+        return self._join(props, outerjoin=True, create_aliases=aliased, from_joinpoint=from_joinpoint)
 
-    @_generative(__no_statement_condition, __no_limit_offset)
-    def __join(self, keys, outerjoin, create_aliases, from_joinpoint):
+    @_generative(_no_statement_condition, _no_limit_offset)
+    def _join(self, keys, outerjoin, create_aliases, from_joinpoint):
 
         # copy collections that may mutate so they do not affect
         # the copied-from query.
-        self.__currenttables = set(self.__currenttables)
+        self._currenttables = set(self._currenttables)
         self._polymorphic_adapters = self._polymorphic_adapters.copy()
 
         # start from the beginning unless from_joinpoint is set.
         if not from_joinpoint:
-            self.__reset_joinpoint()
+            self._reset_joinpoint()
 
         clause = replace_clause_index = None
         
@@ -1027,11 +1081,11 @@ class Query(object):
 
                 elif prop:
                     # for joins across plain relation()s, try not to specify the
-                    # same joins twice.  the __currenttables collection tracks
+                    # same joins twice.  the _currenttables collection tracks
                     # what plain mapped tables we've joined to already.
 
-                    if prop.table in self.__currenttables:
-                        if prop.secondary is not None and prop.secondary not in self.__currenttables:
+                    if prop.table in self._currenttables:
+                        if prop.secondary is not None and prop.secondary not in self._currenttables:
                             # TODO: this check is not strong enough for different paths to the same endpoint which
                             # does not use secondary tables
                             raise sa_exc.InvalidRequestError("Can't join to property '%s'; a path to this "
@@ -1040,8 +1094,8 @@ class Query(object):
                         continue
 
                     if prop.secondary:
-                        self.__currenttables.add(prop.secondary)
-                    self.__currenttables.add(prop.table)
+                        self._currenttables.add(prop.secondary)
+                    self._currenttables.add(prop.table)
 
                     if of_type:
                         right_entity = of_type
@@ -1097,7 +1151,7 @@ class Query(object):
         # future joins with from_joinpoint=True join from our established right_entity.
         self._joinpoint = right_entity
 
-    @_generative(__no_statement_condition)
+    @_generative(_no_statement_condition)
     def reset_joinpoint(self):
         """return a new Query reset the 'joinpoint' of this Query reset
         back to the starting mapper.  Subsequent generative calls will
@@ -1107,9 +1161,9 @@ class Query(object):
         the root.
 
         """
-        self.__reset_joinpoint()
+        self._reset_joinpoint()
 
-    @_generative(__no_clauseelement_condition)
+    @_generative(_no_clauseelement_condition)
     def select_from(self, from_obj):
         """Set the `from_obj` parameter of the query and return the newly
         resulting ``Query``.  This replaces the table which this Query selects
@@ -1130,7 +1184,7 @@ class Query(object):
             from_obj = from_obj[-1]
         if not isinstance(from_obj, expression.FromClause):
             raise sa_exc.ArgumentError("select_from() accepts FromClause objects only.")
-        self.__set_select_from(from_obj)
+        self._set_select_from(from_obj)
 
     def __getitem__(self, item):
         if isinstance(item, slice):
@@ -1153,7 +1207,7 @@ class Query(object):
         else:
             return list(self[item:item+1])[0]
 
-    @_generative(__no_statement_condition)
+    @_generative(_no_statement_condition)
     def slice(self, start, stop):
         """apply LIMIT/OFFSET to the ``Query`` based on a range and return the newly resulting ``Query``."""
         if start is not None and stop is not None:
@@ -1164,7 +1218,7 @@ class Query(object):
         elif start is not None and stop is None:
             self._offset = (self._offset or 0) + start
 
-    @_generative(__no_statement_condition)
+    @_generative(_no_statement_condition)
     def limit(self, limit):
         """Apply a ``LIMIT`` to the query and return the newly resulting
 
@@ -1173,7 +1227,7 @@ class Query(object):
         """
         self._limit = limit
 
-    @_generative(__no_statement_condition)
+    @_generative(_no_statement_condition)
     def offset(self, offset):
         """Apply an ``OFFSET`` to the query and return the newly resulting
         ``Query``.
@@ -1181,7 +1235,7 @@ class Query(object):
         """
         self._offset = offset
 
-    @_generative(__no_statement_condition)
+    @_generative(_no_statement_condition)
     def distinct(self):
         """Apply a ``DISTINCT`` to the query and return the newly resulting
         ``Query``.
@@ -1197,7 +1251,7 @@ class Query(object):
         """
         return list(self)
 
-    @_generative(__no_clauseelement_condition)
+    @_generative(_no_clauseelement_condition)
     def from_statement(self, statement):
         """Execute the given SELECT statement and return results.
 
@@ -1398,7 +1452,7 @@ class Query(object):
 
         if refresh_state is None:
             q = self._clone()
-            q.__no_criterion_condition("get")
+            q._no_criterion_condition("get")
         else:
             q = self._clone()
 
@@ -1420,7 +1474,7 @@ class Query(object):
 
         if lockmode is not None:
             q._lockmode = lockmode
-        q.__get_options(
+        q._get_options(
             populate_existing=bool(refresh_state),
             version_check=(lockmode is not None),
             only_load_props=only_load_props,
index d2cf50d84978baee9078046d79c2cd39cf8f850c..6aa77868ea2a8e2019455dde9cef253d21481006 100644 (file)
@@ -74,7 +74,7 @@ class FalseDiscriminatorTest(_base.MappedTest):
         global t1
         t1 = Table('t1', metadata, 
                     Column('id', Integer, primary_key=True), 
-                    Column('type', Integer, nullable=False)
+                    Column('type', Boolean, nullable=False)
                 )
         
     def test_false_on_sub(self):
index b079e35f220ed382e62851a8b90628f73284814d..67e933efb764e0e99a39d3120dea621864496874 100644 (file)
@@ -225,6 +225,11 @@ class InvalidGenerationsTest(QueryTest):
 
             assert_raises(sa_exc.InvalidRequestError, q.having, 'foo')
     
+            q.enable_assertions(False).join("addresses")
+            q.enable_assertions(False).filter(User.name=='ed')
+            q.enable_assertions(False).order_by('foo')
+            q.enable_assertions(False).group_by('foo')
+            
     def test_no_from(self):
         s = create_session()
     
@@ -236,6 +241,10 @@ class InvalidGenerationsTest(QueryTest):
         
         q = s.query(User).order_by(User.id)
         assert_raises(sa_exc.InvalidRequestError, q.select_from, users)
+
+        assert_raises(sa_exc.InvalidRequestError, q.select_from, users)
+        
+        q.enable_assertions(False).select_from(users)
         
         # this is fine, however
         q.from_self()