]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
flattened _get_from_objects() into a descriptor/class-bound attribute
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 9 Nov 2008 21:34:59 +0000 (21:34 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 9 Nov 2008 21:34:59 +0000 (21:34 +0000)
lib/sqlalchemy/databases/mssql.py
lib/sqlalchemy/databases/oracle.py
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/sql/expression.py
lib/sqlalchemy/util.py

index 0ccad76ffadf00ccfda102ce2a80a8db7f7328e0..0909587954b88c3646c42ffd999059a8001d0e4b 100644 (file)
@@ -1016,9 +1016,9 @@ class MSSQLCompiler(compiler.DefaultCompiler):
             return self.process(expression._BinaryExpression(binary.right, binary.left, binary.operator), **kwargs)
         else:
             if (binary.operator in (operator.eq, operator.ne)) and (
-                (isinstance(binary.left, expression._FromGrouping) and isinstance(binary.left.element, expression._SelectBaseMixin)) or \
-                (isinstance(binary.right, expression._FromGrouping) and isinstance(binary.right.element, expression._SelectBaseMixin)) or \
-                 isinstance(binary.left, expression._SelectBaseMixin) or isinstance(binary.right, expression._SelectBaseMixin)):
+                (isinstance(binary.left, expression._FromGrouping) and isinstance(binary.left.element, expression._ScalarSelect)) or \
+                (isinstance(binary.right, expression._FromGrouping) and isinstance(binary.right.element, expression._ScalarSelect)) or \
+                 isinstance(binary.left, expression._ScalarSelect) or isinstance(binary.right, expression._ScalarSelect)):
                 op = binary.operator == operator.eq and "IN" or "NOT IN"
                 return self.process(expression._BinaryExpression(binary.left, binary.right, op), **kwargs)
             return super(MSSQLCompiler, self).visit_binary(binary, **kwargs)
index f46a4bd4e3463a9744dfc02664124ef1c9d3df08..9871f6e0e7a416eacaad2872f9425f5b34044253 100644 (file)
@@ -679,10 +679,9 @@ class OracleDialect(default.DefaultDialect):
 
 class _OuterJoinColumn(sql.ClauseElement):
     __visit_name__ = 'outer_join_column'
+    
     def __init__(self, column):
         self.column = column
-    def _get_from_objects(self, **kwargs):
-        return []
 
 class OracleCompiler(compiler.DefaultCompiler):
     """Oracle compiler modifies the lexical structure of Select
index 81250706bea20e927774440790eb070c46ac1ef4..82a698fd07e1e7be56332e8b706ce8c0c9c721e7 100644 (file)
@@ -1791,13 +1791,13 @@ class _ColumnEntity(_QueryEntity):
         # of FROMs for the overall expression - this helps
         # subqueries which were built from ORM constructs from
         # leaking out their entities into the main select construct
-        actual_froms = set(column._get_from_objects())
+        actual_froms = set(column._from_objects)
 
         self.entities = util.OrderedSet(
             elem._annotations['parententity']
             for elem in visitors.iterate(column, {})
             if 'parententity' in elem._annotations
-            and actual_froms.intersection(elem._get_from_objects())
+            and actual_froms.intersection(elem._from_objects)
             )
 
         if self.entities:
index 9b847c78b497973c389a54be3ccff756137c4087..7ca3d7b9e2689dcdc5cfdc51756216e53156b615 100644 (file)
@@ -896,8 +896,8 @@ def _compound_select(keyword, *selects, **kwargs):
 def _is_literal(element):
     return not isinstance(element, Visitable) and not hasattr(element, '__clause_element__')
 
-def _from_objects(*elements, **kwargs):
-    return itertools.chain(*[element._get_from_objects(**kwargs) for element in elements])
+def _from_objects(*elements):
+    return itertools.chain(*[element._from_objects for element in elements])
 
 def _labeled(element):
     if not hasattr(element, 'name'):
@@ -967,6 +967,7 @@ class ClauseElement(Visitable):
 
     _annotations = {}
     supports_execution = False
+    _from_objects = []
     
     def _clone(self):
         """Create a shallow copy of this ClauseElement.
@@ -1010,15 +1011,6 @@ class ClauseElement(Visitable):
         d.pop('_is_clone_of', None)
         return d
         
-    def _get_from_objects(self, **modifiers):
-        """Return objects represented in this ``ClauseElement`` that
-        should be added to the ``FROM`` list of a query, when this
-        ``ClauseElement`` is placed in the column clause of a
-        ``Select`` statement.
-        
-        """
-        raise NotImplementedError(repr(self))
-    
     def _annotate(self, values):
         """return a copy of this ClauseElement with the given annotations dictionary."""
 
@@ -1544,6 +1536,8 @@ class _CompareMixin(ColumnOperators):
             return other.__clause_element__()
         elif not isinstance(other, ClauseElement):
             return self._bind_param(other)
+        elif isinstance(other, _SelectBaseMixin):
+            return other.as_scalar()
         else:
             return other
 
@@ -1754,9 +1748,6 @@ class FromClause(Selectable):
     quote = None
     schema = None
 
-    def _get_from_objects(self, **modifiers):
-        return []
-
     def count(self, whereclause=None, **params):
         """return a SELECT COUNT generated against this ``FromClause``."""
 
@@ -1962,9 +1953,6 @@ class _BindParamClause(ColumnElement):
             self.unique = True
             self.key = _generated_label("%%(%d %s)s" % (id(self), self._orig_key or 'param'))
 
-    def _get_from_objects(self, **modifiers):
-        return []
-
     def bind_processor(self, dialect):
         return self.type.dialect_impl(dialect).bind_processor(dialect)
 
@@ -2009,8 +1997,6 @@ class _TypeClause(ClauseElement):
     def __init__(self, type):
         self.type = type
 
-    def _get_from_objects(self, **modifiers):
-        return []
 
 class _TextClause(ClauseElement):
     """Represent a literal SQL text fragment.
@@ -2064,8 +2050,6 @@ class _TextClause(ClauseElement):
     def get_children(self, **kwargs):
         return self.bindparams.values()
 
-    def _get_from_objects(self, **modifiers):
-        return []
     
 class _Null(ColumnElement):
     """Represent the NULL keyword in a SQL statement.
@@ -2075,11 +2059,8 @@ class _Null(ColumnElement):
     """
 
     def __init__(self):
-        ColumnElement.__init__(self)
         self.type = sqltypes.NULLTYPE
 
-    def _get_from_objects(self, **modifiers):
-        return []
 
 class ClauseList(ClauseElement):
     """Describe a list of clauses, separated by an operator.
@@ -2124,8 +2105,9 @@ class ClauseList(ClauseElement):
     def get_children(self, **kwargs):
         return self.clauses
 
-    def _get_from_objects(self, **modifiers):
-        return list(itertools.chain(*[c._get_from_objects(**modifiers) for c in self.clauses]))
+    @property
+    def _from_objects(self):
+        return list(itertools.chain(*[c._from_objects for c in self.clauses]))
 
     def self_group(self, against=None):
         if self.group and self.operator != against and operators.is_precedent(self.operator, against):
@@ -2168,7 +2150,6 @@ class _CalculatedClause(ColumnElement):
     __visit_name__ = 'calculatedclause'
 
     def __init__(self, name, *clauses, **kwargs):
-        ColumnElement.__init__(self)
         self.name = name
         self.type = sqltypes.to_instance(kwargs.get('type_', None))
         self._bind = kwargs.get('bind', None)
@@ -2199,8 +2180,9 @@ class _CalculatedClause(ColumnElement):
     def get_children(self, **kwargs):
         return self.clause_expr,
 
-    def _get_from_objects(self, **modifiers):
-        return self.clauses._get_from_objects(**modifiers)
+    @property
+    def _from_objects(self):
+        return self.clauses._from_objects
 
     def _bind_param(self, obj):
         return _BindParamClause(self.name, obj, type_=self.type, unique=True)
@@ -2252,7 +2234,6 @@ class _Function(_CalculatedClause, FromClause):
 class _Cast(ColumnElement):
 
     def __init__(self, clause, totype, **kwargs):
-        ColumnElement.__init__(self)
         self.type = sqltypes.to_instance(totype)
         self.clause = _literal_as_binds(clause, None)
         self.typeclause = _TypeClause(self.type)
@@ -2264,13 +2245,13 @@ class _Cast(ColumnElement):
     def get_children(self, **kwargs):
         return self.clause, self.typeclause
 
-    def _get_from_objects(self, **modifiers):
-        return self.clause._get_from_objects(**modifiers)
+    @property
+    def _from_objects(self):
+        return self.clause._from_objects
 
 
 class _UnaryExpression(ColumnElement):
     def __init__(self, element, operator=None, modifier=None, type_=None, negate=None):
-        ColumnElement.__init__(self)
         self.operator = operator
         self.modifier = modifier
 
@@ -2278,8 +2259,9 @@ class _UnaryExpression(ColumnElement):
         self.type = sqltypes.to_instance(type_)
         self.negate = negate
 
-    def _get_from_objects(self, **modifiers):
-        return self.element._get_from_objects(**modifiers)
+    @property
+    def _from_objects(self):
+        return self.element._from_objects
 
     def _copy_internals(self, clone=_clone):
         self.element = clone(self.element)
@@ -2319,7 +2301,6 @@ class _BinaryExpression(ColumnElement):
     """Represent an expression that is ``LEFT <operator> RIGHT``."""
 
     def __init__(self, left, right, operator, type_=None, negate=None, modifiers=None):
-        ColumnElement.__init__(self)
         self.left = _literal_as_text(left).self_group(against=operator)
         self.right = _literal_as_text(right).self_group(against=operator)
         self.operator = operator
@@ -2330,8 +2311,9 @@ class _BinaryExpression(ColumnElement):
         else:
             self.modifiers = modifiers
 
-    def _get_from_objects(self, **modifiers):
-        return self.left._get_from_objects(**modifiers) + self.right._get_from_objects(**modifiers)
+    @property
+    def _from_objects(self):
+        return self.left._from_objects + self.right._from_objects
 
     def _copy_internals(self, clone=_clone):
         self.left = clone(self.left)
@@ -2379,7 +2361,8 @@ class _BinaryExpression(ColumnElement):
 
 class _Exists(_UnaryExpression):
     __visit_name__ = _UnaryExpression.__visit_name__
-
+    _from_objects = []
+    
     def __init__(self, *args, **kwargs):
         if args and isinstance(args[0], _SelectBaseMixin):
             s = args[0]
@@ -2398,9 +2381,6 @@ class _Exists(_UnaryExpression):
         e.element = self.element.correlate(fromclause).self_group()
         return e
 
-    def _get_from_objects(self, **modifiers):
-        return []
-
     def select_from(self, clause):
         """return a new exists() construct with the given expression set as its FROM clause."""
     
@@ -2524,11 +2504,12 @@ class Join(FromClause):
     def _hide_froms(self):
         return itertools.chain(*[_from_objects(x.left, x.right) for x in self._cloned_set])
 
-    def _get_from_objects(self, **modifiers):
+    @property
+    def _from_objects(self):
         return [self] + \
-                self.onclause._get_from_objects(**modifiers) + \
-                self.left._get_from_objects(**modifiers) + \
-                self.right._get_from_objects(**modifiers)
+                self.onclause._from_objects + \
+                self.left._from_objects + \
+                self.right._from_objects
 
 class Alias(FromClause):
     """Represents an table or selectable alias (AS).
@@ -2588,7 +2569,8 @@ class Alias(FromClause):
         if aliased_selectables:
             yield self.element
 
-    def _get_from_objects(self, **modifiers):
+    @property
+    def _from_objects(self):
         return [self]
 
     @property
@@ -2599,7 +2581,6 @@ class _Grouping(ColumnElement):
     """Represent a grouping within a column expression"""
 
     def __init__(self, element):
-        ColumnElement.__init__(self)
         self.element = element
         self.type = getattr(element, 'type', None)
 
@@ -2617,8 +2598,9 @@ class _Grouping(ColumnElement):
     def get_children(self, **kwargs):
         return self.element,
 
-    def _get_from_objects(self, **modifiers):
-        return self.element._get_from_objects(**modifiers)
+    @property
+    def _from_objects(self):
+        return self.element._from_objects
 
     def __getattr__(self, attr):
         return getattr(self.element, attr)
@@ -2652,8 +2634,9 @@ class _FromGrouping(FromClause):
     def _copy_internals(self, clone=_clone):
         self.element = clone(self.element)
 
-    def _get_from_objects(self, **modifiers):
-        return self.element._get_from_objects(**modifiers)
+    @property
+    def _from_objects(self):
+        return self.element._from_objects
 
     def __getattr__(self, attr):
         return getattr(self.element, attr)
@@ -2704,8 +2687,9 @@ class _Label(ColumnElement):
     def _copy_internals(self, clone=_clone):
         self.element = clone(self.element)
 
-    def _get_from_objects(self, **modifiers):
-        return self.element._get_from_objects(**modifiers)
+    @property
+    def _from_objects(self):
+        return self.element._from_objects
 
     def _make_proxy(self, selectable, name = None):
         if isinstance(self.element, (Selectable, ColumnElement)):
@@ -2743,7 +2727,6 @@ class _ColumnClause(_Immutable, ColumnElement):
       
     """
     def __init__(self, text, selectable=None, type_=None, is_literal=False):
-        ColumnElement.__init__(self)
         self.key = self.name = text
         self.table = selectable
         self.type = sqltypes.to_instance(type_)
@@ -2783,7 +2766,8 @@ class _ColumnClause(_Immutable, ColumnElement):
         else:
             return super(_ColumnClause, self).label(name)
 
-    def _get_from_objects(self, **modifiers):
+    @property
+    def _from_objects(self):
         if self.table:
             return [self.table]
         else:
@@ -2858,7 +2842,8 @@ class TableClause(_Immutable, FromClause):
     def delete(self, whereclause=None, **kwargs):
         return delete(self, whereclause, **kwargs)
 
-    def _get_from_objects(self, **modifiers):
+    @property
+    def _from_objects(self):
         return [self]
 
 @util.decorator
@@ -2994,15 +2979,15 @@ class _SelectBaseMixin(object):
                 clauses = list(self._group_by_clause) + list(clauses)
             self._group_by_clause = ClauseList(*clauses)
 
-    def _get_from_objects(self, is_where=False, **modifiers):
-        if is_where:
-            return []
-        else:
-            return [self]
+    @property
+    def _from_objects(self):
+        return [self]
 
+        
 class _ScalarSelect(_Grouping):
     __visit_name__ = 'grouping'
-
+    _from_objects = []
+    
     def __init__(self, element):
         self.element = element
         cols = list(element.c)
@@ -3023,9 +3008,6 @@ class _ScalarSelect(_Grouping):
     def _make_proxy(self, selectable, name):
         return list(self.inner_columns)[0]._make_proxy(selectable, name)
 
-    def _get_from_objects(self, **modifiers):
-        return []
-
 class CompoundSelect(_SelectBaseMixin, FromClause):
     def __init__(self, keyword, *selects, **kwargs):
         self._should_correlate = kwargs.pop('correlate', False)
@@ -3125,7 +3107,7 @@ class Select(_SelectBaseMixin, FromClause):
 
         if whereclause:
             self._whereclause = _literal_as_text(whereclause)
-            self._froms.update(_from_objects(self._whereclause, is_where=True))
+            self._froms.update(_from_objects(self._whereclause))
         else:
             self._whereclause = None
 
@@ -3348,7 +3330,7 @@ class Select(_SelectBaseMixin, FromClause):
 
         """
         whereclause = _literal_as_text(whereclause)
-        self._froms = self._froms.union(_from_objects(whereclause, is_where=True))
+        self._froms = self._froms.union(_from_objects(whereclause))
         
         if self._whereclause is not None:
             self._whereclause = and_(self._whereclause, whereclause)
index 68cd4adc63ab49a8170b02c47a716afa4bdc9351..f2f2cdd79c900088c4645b5cdd6107caa1de01f1 100644 (file)
@@ -1318,7 +1318,6 @@ class memoized_instancemethod(object):
         oneshot.__doc__ = self.__doc__
         return oneshot
 
-
 def reset_memoized(instance, name):
     try:
         del instance.__dict__[name]