]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
added compare function to the more basic expression objects
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 5 Jan 2006 05:44:10 +0000 (05:44 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 5 Jan 2006 05:44:10 +0000 (05:44 +0000)
adding priamry_key/foreign_keys to selects, alias etc to increase their useability for relating them to tables
improved _get_col_by_original to double-check the column it finds

lib/sqlalchemy/sql.py

index 486fa304ea379eaa23f4ad72f2da0d919fa1fd6d..2a1b6829f8231a86bda2f16d2805a61a89f206cb 100644 (file)
@@ -303,6 +303,13 @@ class ClauseElement(object):
             data.setdefault(f.id, f)
         if asfrom:
             data[self.id] = self
+    def compare(self, other):
+        """compares this ClauseElement to the given ClauseElement.
+        
+        Subclasses should override the default behavior, which is a straight
+        identity comparison."""
+        return self is other
+        
     def accept_visitor(self, visitor):
         """accepts a ClauseVisitor and calls the appropriate visit_xxx method."""
         raise NotImplementedError(repr(self))
@@ -521,6 +528,12 @@ class BindParamClause(ClauseElement, CompareMixin):
         return "BindParam(%s, %s, %s)" % (repr(self.key), repr(self.value), repr(self.shortname))
     def typeprocess(self, value, engine):
         return self.type.convert_bind_param(value, engine)
+    def compare(self, other):
+        """compares this BindParamClause to the given clause.
+        
+        Since compare() is meant to compare statement syntax, this method
+        returns True if the two BindParamClauses have just the same type."""
+        return isinstance(other, BindParamClause) and other.type.__class__ == self.type.__class__
             
 class TextClause(ClauseElement):
     """represents literal a SQL text fragment.  public constructor is the 
@@ -580,6 +593,17 @@ class ClauseList(ClauseElement):
         for c in self.clauses:
             f += c._get_from_objects()
         return f
+    def compare(self, other):
+        """compares this ClauseList to the given ClauseList, including
+        a comparison of all the clause items."""
+        if isinstance(other, ClauseList) and len(self.clauses) == len(other.clauses):
+            for i in range(0, len(self.clauses)):
+                if not self.clauses[i].compare(other.clauses[i]):
+                    return False
+            else:
+                return True
+        else:
+            return False
 
 class CompoundClause(ClauseList):
     """represents a list of clauses joined by an operator, such as AND or OR.  
@@ -606,8 +630,20 @@ class CompoundClause(ClauseList):
         return f
     def hash_key(self):
         return string.join([c.hash_key() for c in self.clauses], self.operator or " ")
-
-    
+    def compare(self, other):
+        """compares this CompoundClause to the given item.  
+        
+        In addition to the regular comparison, has the special case that it 
+        returns True if this CompoundClause has only one item, and that 
+        item matches the given item."""
+        if not isinstance(other, CompoundClause):
+            if len(self.clauses) == 1:
+                return self.clauses[0].compare(other)
+        if ClauseList.compare(self, other):
+            return self.operator == other.operator
+        else:
+            return False
+                
 class Function(ClauseList, CompareMixin):
     """describes a SQL function. extends ClauseList to provide comparison operators."""
     def __init__(self, name, *clauses, **kwargs):
@@ -641,7 +677,6 @@ class Function(ClauseList, CompareMixin):
         return BinaryClause(self, obj, operator)
     def _make_proxy(self, selectable, name = None):
         return self
-
         
 class BinaryClause(ClauseElement, CompareMixin):
     """represents two clauses with an operator in between"""
@@ -665,7 +700,13 @@ class BinaryClause(ClauseElement, CompareMixin):
         c = self.left
         self.left = self.right
         self.right = c
-
+    def compare(self, other):
+        """compares this BinaryClause against the given BinaryClause."""
+        return (
+            isinstance(other, BinaryClause) and self.operator == other.operator and 
+            self.left.compare(other.left) and self.right.compare(other.right)
+        )
+        
 class Join(FromClause):
     # TODO: put "using" + "natural" concepts in here and make "onclause" optional
     def __init__(self, left, right, onclause=None, isouter = False, allcols = True):
@@ -713,7 +754,7 @@ class Join(FromClause):
 
     def _get_col_by_original(self, column):
         for c in self.columns:
-            if c.original is column:
+            if c.original is column.original:
                 return c
         else:
             return None
@@ -753,7 +794,6 @@ class Alias(FromClause):
     def __init__(self, selectable, alias = None):
         self.selectable = selectable
         self._columns = util.OrderedProperties()
-        self.foreign_keys = []
         if alias is None:
             n = getattr(selectable, 'name')
             if n is None:
@@ -770,12 +810,17 @@ class Alias(FromClause):
             co._make_proxy(self)
 
     primary_key = property (lambda self: [c for c in self.columns if c.primary_key])
-    
+    foreign_keys = property(lambda s:s.selectable.foreign_keys)
+
     def hash_key(self):
         return "Alias(%s, %s)" % (repr(self.selectable.hash_key()), repr(self.name))
 
     def _get_col_by_original(self, column):
-        return self.columns.get(column.key, None)
+        c = self.columns.get(column.key, None)
+        if c is not None and c.original is column.original:
+            return c
+        else:
+            return None
 
     def accept_visitor(self, visitor):
         self.selectable.accept_visitor(visitor)
@@ -874,8 +919,12 @@ class ColumnImpl(ColumnElement):
     def copy_container(self):
         return self.column
 
+    def compare(self, other):
+        """compares this ColumnImpl's column to the other given Column"""
+        return self.column is other
+        
     def _get_col_by_original(self, column):
-        if self.column.original is column:
+        if self.column.original is column.original:
             return self.column
         else:
             return None
@@ -932,7 +981,7 @@ class TableImpl(FromClause):
           col = self.columns[column.key]
         except KeyError:
           return None
-        if col.original is column:
+        if col.original is column.original:
           return col
         else:
           return None
@@ -1016,6 +1065,8 @@ class CompoundSelect(SelectBaseMixin, FromClause):
         if group_by:
             self.group_by(*group_by)
 
+    primary_key = property(lambda s:s.selects[0].primary_key)
+    foreign_keys = property(lambda s:s.selects[0].foreign_keys)
     columns = property(lambda s:s.selects[0].columns)
     def accept_visitor(self, visitor):
         for tup in self.clauses:
@@ -1119,9 +1170,13 @@ class Select(SelectBaseMixin, FromClause):
             
     def _get_col_by_original(self, column):
         if self.use_labels:
-            return self.columns.get(column._label,None)
+            c = self.columns.get(column._label,None)
         else:
-            return self.columns.get(column.key,None)
+            c = self.columns.get(column.key,None)
+        if c is not None and c.original is column.original:
+            return c
+        else:
+            return None
 
     def _pks(self):
         ret = {}
@@ -1130,8 +1185,15 @@ class Select(SelectBaseMixin, FromClause):
                 if c.primary_key:
                     ret[c] = c
         return ret.keys()
-    primary_key = property (lambda self: self._pks())
-
+    def _fks(self):
+        ret = []
+        for from_obj in self._get_froms():
+            for fk in from_obj.foreign_keys:
+                ret.append(fk)
+        return ret
+    primary_key = property (_pks)
+    foreign_keys = property(_fks)
+    
     def append_whereclause(self, whereclause):
         self._append_condition('whereclause', whereclause)
     def append_having(self, having):