]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
refactoring to allow column.label() to work in selects, etc.
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 27 Jan 2006 00:33:15 +0000 (00:33 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 27 Jan 2006 00:33:15 +0000 (00:33 +0000)
fixed superfluous codeline in ForeignKey

lib/sqlalchemy/schema.py
lib/sqlalchemy/sql.py
test/selectable.py

index c55d9034a030c7e068ba6ccf0994952f1bbb1ec6..6679c8b0575718d8705785d9aae2d0045a73730f 100644 (file)
@@ -367,10 +367,7 @@ class ForeignKey(SchemaItem):
         
     def references(self, table):
         """returns True if the given table is referenced by this ForeignKey."""
-        try:
-            return table._get_col_by_original(self.column) is not None
-        except:
-            x = self._init_column()
+        return table._get_col_by_original(self.column) is not None
         
     def _init_column(self):
         # ForeignKey inits its remote column as late as possible, so tables can
index c37a8d620f97301ca5bb83967f9f2402e395265e..c3048de29422fd928da31405b49f3a3d5d401fb1 100644 (file)
@@ -121,12 +121,12 @@ def or_(*clauses):
 def not_(clause):
     """returns a negation of the given clause, i.e. NOT(clause).  the ~ operator can be used as well."""
     clause.parens=True
-    return BinaryClause(TextClause("NOT"), clause, None)
+    return BooleanExpression(TextClause("NOT"), clause, None)
             
         
 def exists(*args, **params):
     s = select(*args, **params)
-    return BinaryClause(TextClause("EXISTS"), s, None)
+    return BooleanExpression(TextClause("EXISTS"), s, None)
 
 def union(*selects, **params):
     return _compound_select('UNION', *selects, **params)
@@ -441,15 +441,15 @@ class CompareMixin(object):
         return lambda other: self._compare(operator, other)
     # and here come the math operators:
     def __add__(self, other):
-        return self._compare('+', other)
+        return self._operate('+', other)
     def __sub__(self, other):
-        return self._compare('-', other)
+        return self._operate('-', other)
     def __mul__(self, other):
-        return self._compare('*', other)
+        return self._operate('*', other)
     def __div__(self, other):
-        return self._compare('/', other)
+        return self._operate('/', other)
     def __truediv__(self, other):
-        return self._compare('/', other)
+        return self._operate('/', other)
     def _bind_param(self, obj):
         return BindParamClause('literal', obj, shortname=None, type=self.type)
     def _compare(self, operator, obj):
@@ -457,12 +457,24 @@ class CompareMixin(object):
             if obj is None:
                 if operator != '=':
                     raise "Only '=' operator can be used with NULL"
-                return BinaryClause(self, null(), 'IS')
+                return BooleanExpression(self, null(), 'IS')
             else:
                 obj = self._bind_param(obj)
 
-        return BinaryClause(self, obj, operator, type=obj.type)
-
+        return BooleanExpression(self._compare_self(), obj, operator, type=self._compare_type(obj))
+    def _operate(self, operator, obj):
+        if _is_literal(obj):
+            obj = self._bind_param(obj)
+        return BinaryExpression(self._compare_self(), obj, operator, type=self._compare_type(obj))
+    def _compare_self(self):
+        """allows ColumnImpl to return its Column object for usage in ClauseElements, all others to
+        just return self"""
+        return self
+    def _compare_type(self, obj):
+        """allows subclasses to override the type used in constructing BinaryClause objects.  Default return
+        value is the type of the given object."""
+        return obj.type
+        
 class Selectable(ClauseElement):
     """represents a column list-holding object."""
 
@@ -482,11 +494,27 @@ class Selectable(ClauseElement):
 
 
 class ColumnElement(Selectable, CompareMixin):
-    """represents a column element within the list of a Selectable's columns."""
+    """represents a column element within the list of a Selectable's columns.  Provides 
+    default implementations for the things a "column" needs, including a "primary_key" flag,
+    a "foreign_key" accessor, an "original" accessor which represents the ultimate column
+    underlying a string of labeled/select-wrapped columns, and "columns" which returns a list
+    of the single column, providing the same list-based interface as a FromClause."""
     primary_key = property(lambda self:getattr(self, '_primary_key', False))
     foreign_key = property(lambda self:getattr(self, '_foreign_key', False))
     original = property(lambda self:getattr(self, '_original', self))
     columns = property(lambda self:[self])
+    def _make_proxy(self, selectable, name=None):
+        """creates a new ColumnElement representing this ColumnElement as it appears in the select list
+        of an enclosing selectable.  The default implementation returns a ColumnClause if a name is given,
+        else just returns self.  This has various mechanics with schema.Column and sql.Label so that 
+        Column objects as well as non-column objects like Function and BinaryClause can both appear in the 
+        select list of an enclosing selectable."""
+        if name is not None:
+            co = ColumnClause(name, selectable)
+            selectable.columns[name]= co
+            return co
+        else:
+            return self
 
 class FromClause(Selectable):
     """represents an element that can be used within the FROM clause of a SELECT statement."""
@@ -520,7 +548,7 @@ class FromClause(Selectable):
         (i.e. an "original" column), return the Column object from this 
         Selectable which corresponds to that original Column, or None if this Selectable
         does not contain the column."""
-        return self.original_columns[column.original]
+        return self.original_columns.get(column.original, None)
     def _get_exported_attribute(self, name):
         try:
             return getattr(self, name)
@@ -706,7 +734,7 @@ class CompoundClause(ClauseList):
         else:
             return False
                 
-class Function(ClauseList, CompareMixin):
+class Function(ClauseList, ColumnElement):
     """describes a SQL function. extends ClauseList to provide comparison operators."""
     def __init__(self, name, *clauses, **kwargs):
         self.name = name
@@ -722,29 +750,19 @@ class Function(ClauseList, CompareMixin):
         self.clauses.append(clause)
     def copy_container(self):
         clauses = [clause.copy_container() for clause in self.clauses]
-        return Function(self.name, label=self.label, type=self.type, *clauses)
+        return Function(self.name, type=self.type, *clauses)
     def accept_visitor(self, visitor):
         for c in self.clauses:
             c.accept_visitor(visitor)
         visitor.visit_function(self)
-    def _compare(self, operator, obj):
-        if _is_literal(obj):
-            if obj is None:
-                if operator != '=':
-                    raise "Only '=' operator can be used with NULL"
-                return BinaryClause(self, null(), 'IS')
-            else:
-                obj = BindParamClause(self.name, obj, shortname=self.name, type=self.type)
-
-        return BinaryClause(self, obj, operator)
-    def _make_proxy(self, selectable, name = None):
-        return self
+    def _bind_param(self, obj):
+        return BindParamClause(self.name, obj, shortname=self.name, type=self.type)
     def select(self):
         return select([self])
     def hash_key(self):
         return self.name + "(" + string.join([c.hash_key() for c in self.clauses], ", ") + ")"
             
-class BinaryClause(ClauseElement, CompareMixin):
+class BinaryClause(ClauseElement):
     """represents two clauses with an operator in between"""
     def __init__(self, left, right, operator, type=None):
         self.left = left
@@ -772,6 +790,16 @@ class BinaryClause(ClauseElement, CompareMixin):
             isinstance(other, BinaryClause) and self.operator == other.operator and 
             self.left.compare(other.left) and self.right.compare(other.right)
         )
+
+class BooleanExpression(BinaryClause):
+    """represents a boolean expression, which is only useable in WHERE criterion."""
+    pass
+class BinaryExpression(BinaryClause, ColumnElement):
+    """represents a binary expression, which can be in a WHERE criterion or in the column list 
+    of a SELECT.  By adding "ColumnElement" to its inherited list, it becomes a Selectable
+    unit which can be placed in the column list of a SELECT."""
+    pass
+    
         
 class Join(FromClause):
     def __init__(self, left, right, onclause=None, isouter = False):
@@ -873,7 +901,7 @@ class Alias(FromClause):
         return self.selectable.columns
 
     def hash_key(self):
-        return "Alias(%s, %s)" % (repr(self.selectable.hash_key()), repr(self.name))
+        return "Alias(%s, %s)" % (self.selectable.hash_key(), repr(self.name))
 
     def accept_visitor(self, visitor):
         self.selectable.accept_visitor(visitor)
@@ -897,17 +925,15 @@ class Label(ColumnElement):
         obj.parens=True
     key = property(lambda s: s.name)
     _label = property(lambda s: s.name)
+    original = property(lambda s:s.obj.original)
     def accept_visitor(self, visitor):
         self.obj.accept_visitor(visitor)
         visitor.visit_label(self)
     def _get_from_objects(self):
         return self.obj._get_from_objects()
     def _make_proxy(self, selectable, name = None):
-        # TODO: this make_proxy needs foreign_key, primary_key support
-        # from its underlying column, if any
-        cc = ColumnClause(self.name, selectable)
-        selectable.c[self.name] = cc
-        return cc
+        return self.obj._make_proxy(selectable, name=self.name)
+        
     def hash_key(self):
         return "Label(%s, %s)" % (self.name, self.obj.hash_key())
      
@@ -936,19 +962,11 @@ class ColumnClause(ColumnElement):
     def _get_from_objects(self):
         return []
 
-    def _compare(self, operator, obj):
-        if _is_literal(obj):
-            if obj is None:
-                if operator != '=':
-                    raise "Only '=' operator can be used with NULL"
-                return BinaryClause(self, null(), 'IS')
-            elif self.table.name is None:
-                obj = BindParamClause(self.text, obj, shortname=self.text, type=self.type)
-            else:
-                obj = BindParamClause(self.table.name + "_" + self.text, obj, shortname = self.text, type=self.type)
-
-        return BinaryClause(self, obj, operator)
-
+    def _bind_param(self, obj):
+        if self.table.name is None:
+            return BindParamClause(self.text, obj, shortname=self.text, type=self.type)
+        else:
+            return BindParamClause(self.table.name + "_" + self.text, obj, shortname = self.text, type=self.type)
     def _make_proxy(self, selectable, name = None):
         c = ColumnClause(self.text or name, selectable)
         selectable.columns[c.key] = c
@@ -992,18 +1010,13 @@ class ColumnImpl(ColumnElement):
             return BindParamClause(self.name, obj, shortname = self.name, type = self.column.type)
         else:
             return BindParamClause(self.column.table.name + "_" + self.name, obj, shortname = self.name, type = self.column.type)
-
-    def _compare(self, operator, obj):
-        if _is_literal(obj):
-            if obj is None:
-                if operator != '=':
-                    raise "Only '=' operator can be used with NULL"
-                return BinaryClause(self.column, null(), 'IS')
-            else:
-                obj = self._bind_param(obj)
-
-        return BinaryClause(self.column, obj, operator, type=self.column.type)
-
+    def _compare_self(self):
+        """allows ColumnImpl to return its Column object for usage in ClauseElements, all others to
+        just return self"""
+        return self.column
+    def _compare_type(self, obj):
+        return self.column.type
+        
     def compile(self, engine = None, parameters = None, typemap=None):
         if engine is None:
             engine = self.engine
@@ -1226,7 +1239,7 @@ class Select(SelectBaseMixin, FromClause):
         for f in column._get_from_objects():
             f.accept_visitor(self._correlator)
         column._process_from_dict(self._froms, False)
-
+        
     def _exportable_columns(self):
         return self._raw_columns
     def _proxy_column(self, column):
@@ -1259,11 +1272,11 @@ class Select(SelectBaseMixin, FromClause):
         try:
             return "Select(%s)" % string.join(
                 [
-                    "columns=" + repr([util.hash_key(c) for c in self._raw_columns]),
+                    "columns=" + string.join([util.hash_key(c) for c in self._raw_columns],','),
                     "where=" + util.hash_key(self.whereclause),
-                    "from=" + repr([util.hash_key(f) for f in self.froms]),
+                    "from=" + string.join([util.hash_key(f) for f in self.froms],','),
                     "having=" + util.hash_key(self.having),
-                    "clauses=" + repr([util.hash_key(c) for c in self.clauses])
+                    "clauses=" + string.join([util.hash_key(c) for c in self.clauses], ',')
                 ] + ["%s=%s" % (k, repr(getattr(self, k))) for k in ['use_labels', 'distinct', 'limit', 'offset']], ","
             ) 
         finally:
index 37b1f28eca4f61f9f39f70d83592625240838138..aefef25e06944ee7e1f18d41b97ea157c94e9527 100755 (executable)
@@ -74,7 +74,11 @@ class SelectableTest(testbase.AssertMixin):
         a = select([table.c.col1.label('acol1'), table.c.col2.label('acol2'), table.c.col3.label('acol3')])\r
         print str(a)\r
         print [c for c in a.columns]\r
+        print str(a.select())\r
         j = join(a, table2)\r
+        criterion = a.c.acol1 == table2.c.col2\r
+        print str(j)\r
+        self.assert_(criterion.compare(j.onclause))\r
         \r
     def testselectaliaslabels(self):\r
         a = table2.select(use_labels=True).alias('a')\r