]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
implemented better hash_key on select allowing proper comparisons, implemented
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 19 Jan 2006 01:43:26 +0000 (01:43 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 19 Jan 2006 01:43:26 +0000 (01:43 +0000)
hash_key on all clause objects
added hash_key test to select
util gets extra threadlocal functions and the recursionstack object

lib/sqlalchemy/sql.py
lib/sqlalchemy/util.py
test/select.py

index 782eef02ff3133546150e26dc7c6ec173ae3e69e..e92246a404a5127312561886d44b0098e7b992a7 100644 (file)
@@ -601,6 +601,8 @@ class ClauseList(ClauseElement):
             if c is None: continue
             self.append(c)
         self.parens = kwargs.get('parens', False)
+    def hash_key(self):
+        return string.join([c.hash_key() for c in self.clauses], ",")
     def copy_container(self):
         clauses = [clause.copy_container() for clause in self.clauses]
         return ClauseList(parens=self.parens, *clauses)
@@ -674,7 +676,7 @@ class Function(ClauseList, CompareMixin):
         self.name = name
         self.type = kwargs.get('type', sqltypes.NULLTYPE)
         ClauseList.__init__(self, parens=True, *clauses)
-    key = property(lambda self:self.label or self.name)
+    key = property(lambda self:self.name)
     def append(self, clause):
         if _is_literal(clause):
             if clause is None:
@@ -703,6 +705,8 @@ class Function(ClauseList, CompareMixin):
         return self
     def select(self):
         return select([self])
+    def hash_key(self):
+        return self.name + "(" + string.join([c.hash_key() for c in self.clauses], ", ") + ")"
             
 class BinaryClause(ClauseElement, CompareMixin):
     """represents two clauses with an operator in between"""
@@ -717,7 +721,7 @@ class BinaryClause(ClauseElement, CompareMixin):
     def _get_from_objects(self):
         return self.left._get_from_objects() + self.right._get_from_objects()
     def hash_key(self):
-        return self.left.hash_key() + self.operator + self.right.hash_key()
+        return self.left.hash_key() + (self.operator or " ") + self.right.hash_key()
     def accept_visitor(self, visitor):
         self.left.accept_visitor(visitor)
         self.right.accept_visitor(visitor)
@@ -868,6 +872,8 @@ class Label(ColumnElement):
         cc = ColumnClause(self.name, selectable)
         selectable.c[self.name] = cc
         return cc
+    def hash_key(self):
+        return "Label(%s, %s)" % (self.name, self.obj.hash_key())
      
 class ColumnClause(ColumnElement):
     """represents a textual column clause in a SQL statement. allows the creation
@@ -887,7 +893,7 @@ class ColumnClause(ColumnElement):
 
     def hash_key(self):
         if self.table is not None:
-            return "ColumnClause(%s, %s)" % (self.text, self.table.hash_key())
+            return "ColumnClause(%s, %s)" % (self.text, util.hash_key(self.table))
         else:
             return "ColumnClause(%s)" % self.text
 
@@ -1086,7 +1092,11 @@ class CompoundSelect(SelectBaseMixin, FromClause):
         group_by = kwargs.get('group_by', None)
         if group_by:
             self.group_by(*group_by)
-
+    def hash_key(self):
+        return "CompoundSelect(%s)" % string.join(
+            [util.hash_key(s) for s in self.selects] + 
+            ["%s=%s" % (k, repr(getattr(self, k))) for k in ['use_labels', 'keyword']],
+            ",")
     def _exportable_columns(self):
         return self.selects[0].columns
     def _proxy_column(self, column):
@@ -1203,9 +1213,26 @@ class Select(SelectBaseMixin, FromClause):
         else:
             setattr(self, attribute, condition)
 
-    def hash_key(self):
-        return "Select(%d)" % (id(self))
+    _hash_recursion = util.RecursionStack()
     
+    def hash_key(self):
+        # selects call alot of stuff so we do some "recursion checking"
+        # to eliminate loops
+        if Select._hash_recursion.push(self):
+            return "recursive_select()"
+        try:
+            return "Select(%s)" % string.join(
+                [
+                    "columns=" + repr([util.hash_key(c) for c in self._raw_columns]),
+                    "where=" + util.hash_key(self.whereclause),
+                    "from=" + repr([util.hash_key(f) for f in self.froms]),
+                    "having=" + util.hash_key(self.having),
+                    "clauses=" + repr([util.hash_key(c) for c in self.clauses])
+                ] + ["%s=%s" % (k, repr(getattr(self, k))) for k in ['use_labels', 'distinct', 'limit', 'offset']], ","
+            ) 
+        finally:
+            Select._hash_recursion.pop(self)
+        
     def clear_from(self, id):
         self.append_from(FromClause(from_name = None, from_key = id))
         
@@ -1256,6 +1283,9 @@ class Select(SelectBaseMixin, FromClause):
 class UpdateBase(ClauseElement):
     """forms the base for INSERT, UPDATE, and DELETE statements."""
     
+    def hash_key(self):
+        return str(id(self))
+        
     def _process_colparams(self, parameters):
         """receives the "values" of an INSERT or UPDATE statement and constructs
         appropriate ind parameters."""
index 7f417ba7e65320c98d7f365e3dfbb1df1f4635b0..8592dd98d2f6ed481a2bb638a497e71efa53caa8 100644 (file)
@@ -18,6 +18,16 @@ def to_list(x):
 def generic_repr(obj, exclude=None):
     L = ['%s=%s' % (a, repr(getattr(obj, a))) for a in dir(obj) if not callable(getattr(obj, a)) and not a.startswith('_') and (exclude is None or not exclude.has_key(a))]
     return '%s(%s)' % (obj.__class__.__name__, ','.join(L))
+
+def hash_key(obj):
+    if obj is None:
+        return 'None'
+    elif isinstance(obj, list):
+        return repr([hash_key(o) for o in obj])
+    elif hasattr(obj, 'hash_key'):
+        return obj.hash_key()
+    else:
+        return repr(obj)
         
 class OrderedProperties(object):
     """an object that maintains the order in which attributes are set upon it.
@@ -49,7 +59,30 @@ class OrderedProperties(object):
     
         self.__dict__[key] = object
     
-
+class RecursionStack(object):
+    """a thread-local stack used to detect recursive object traversals."""
+    def __init__(self):
+        self.stacks = {}
+    def _get_stack(self):
+        try:
+            stack = self.stacks[thread.get_ident()]
+        except KeyError:
+            stack = {}
+            self.stacks[thread.get_ident()] = stack
+        return stack
+    def push(self, obj):
+        s = self._get_stack()
+        if s.has_key(obj):
+            return True
+        else:
+            s[obj] = True
+            return False
+    def pop(self, obj):
+        stack = self._get_stack()
+        del stack[obj]
+        if len(stack) == 0:
+            del self.stacks[thread.get_ident()]
+        
 class OrderedDict(dict):
     """A Dictionary that keeps its own internal ordering"""
     def __init__(self, values = None):
@@ -110,6 +143,13 @@ class ThreadLocal(object):
     def __init__(self, raiseerror = True):
         self.__dict__['_tdict'] = {}
         self.__dict__['_raiseerror'] = raiseerror
+    def __hasattr__(self, key):
+        return self._tdict.has_key("%d_%s" % (thread.get_ident(), key))
+    def __delattr__(self, key):
+        try:
+            del self._tdict["%d_%s" % (thread.get_ident(), key)]
+        except KeyError:
+            raise AttributeError(key)
     def __getattr__(self, key):
         try:
             return self._tdict["%d_%s" % (thread.get_ident(), key)]
@@ -121,6 +161,7 @@ class ThreadLocal(object):
     def __setattr__(self, key, value):
         self._tdict["%d_%s" % (thread.get_ident(), key)] = value
 
+
 class HashSet(object):
     """implements a Set."""
     def __init__(self, iter = None, ordered = False):
index 1fa2fd456b1b6ba4c67fe3a913898f5fa4629d7c..ca0eb0eac19826eebff0309aebeed00912acc6ca 100644 (file)
@@ -55,7 +55,8 @@ addresses = Table('addresses', db,
 class SQLTest(PersistTest):
     def runtest(self, clause, result, engine = None, params = None, checkparams = None):
         c = clause.compile(engine, params)
-        self.echo("\n" + str(c) + repr(c.get_params()))
+        self.echo("\nSQL String:\n" + str(c) + repr(c.get_params()))
+        self.echo("\nHash Key:\n" + clause.hash_key())
         cc = re.sub(r'\n', '', str(c))
         self.assert_(cc == result, str(c) + "\n does not match \n" + result)
         if checkparams is not None: