From: Mike Bayer Date: Thu, 19 Jan 2006 01:43:26 +0000 (+0000) Subject: implemented better hash_key on select allowing proper comparisons, implemented X-Git-Tag: rel_0_1_0~122 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=c9e7e698e60d9d15e113a0b936ea630bcda5443a;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git implemented better hash_key on select allowing proper comparisons, implemented hash_key on all clause objects added hash_key test to select util gets extra threadlocal functions and the recursionstack object --- diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py index 782eef02ff..e92246a404 100644 --- a/lib/sqlalchemy/sql.py +++ b/lib/sqlalchemy/sql.py @@ -601,6 +601,8 @@ class ClauseList(ClauseElement): if c is None: continue self.append(c) self.parens = kwargs.get('parens', False) + def hash_key(self): + return string.join([c.hash_key() for c in self.clauses], ",") def copy_container(self): clauses = [clause.copy_container() for clause in self.clauses] return ClauseList(parens=self.parens, *clauses) @@ -674,7 +676,7 @@ class Function(ClauseList, CompareMixin): self.name = name self.type = kwargs.get('type', sqltypes.NULLTYPE) ClauseList.__init__(self, parens=True, *clauses) - key = property(lambda self:self.label or self.name) + key = property(lambda self:self.name) def append(self, clause): if _is_literal(clause): if clause is None: @@ -703,6 +705,8 @@ class Function(ClauseList, CompareMixin): return self def select(self): return select([self]) + def hash_key(self): + return self.name + "(" + string.join([c.hash_key() for c in self.clauses], ", ") + ")" class BinaryClause(ClauseElement, CompareMixin): """represents two clauses with an operator in between""" @@ -717,7 +721,7 @@ class BinaryClause(ClauseElement, CompareMixin): def _get_from_objects(self): return self.left._get_from_objects() + self.right._get_from_objects() def hash_key(self): - return self.left.hash_key() + self.operator + self.right.hash_key() + return self.left.hash_key() + (self.operator or " ") + self.right.hash_key() def accept_visitor(self, visitor): self.left.accept_visitor(visitor) self.right.accept_visitor(visitor) @@ -868,6 +872,8 @@ class Label(ColumnElement): cc = ColumnClause(self.name, selectable) selectable.c[self.name] = cc return cc + def hash_key(self): + return "Label(%s, %s)" % (self.name, self.obj.hash_key()) class ColumnClause(ColumnElement): """represents a textual column clause in a SQL statement. allows the creation @@ -887,7 +893,7 @@ class ColumnClause(ColumnElement): def hash_key(self): if self.table is not None: - return "ColumnClause(%s, %s)" % (self.text, self.table.hash_key()) + return "ColumnClause(%s, %s)" % (self.text, util.hash_key(self.table)) else: return "ColumnClause(%s)" % self.text @@ -1086,7 +1092,11 @@ class CompoundSelect(SelectBaseMixin, FromClause): group_by = kwargs.get('group_by', None) if group_by: self.group_by(*group_by) - + def hash_key(self): + return "CompoundSelect(%s)" % string.join( + [util.hash_key(s) for s in self.selects] + + ["%s=%s" % (k, repr(getattr(self, k))) for k in ['use_labels', 'keyword']], + ",") def _exportable_columns(self): return self.selects[0].columns def _proxy_column(self, column): @@ -1203,9 +1213,26 @@ class Select(SelectBaseMixin, FromClause): else: setattr(self, attribute, condition) - def hash_key(self): - return "Select(%d)" % (id(self)) + _hash_recursion = util.RecursionStack() + def hash_key(self): + # selects call alot of stuff so we do some "recursion checking" + # to eliminate loops + if Select._hash_recursion.push(self): + return "recursive_select()" + try: + return "Select(%s)" % string.join( + [ + "columns=" + repr([util.hash_key(c) for c in self._raw_columns]), + "where=" + util.hash_key(self.whereclause), + "from=" + repr([util.hash_key(f) for f in self.froms]), + "having=" + util.hash_key(self.having), + "clauses=" + repr([util.hash_key(c) for c in self.clauses]) + ] + ["%s=%s" % (k, repr(getattr(self, k))) for k in ['use_labels', 'distinct', 'limit', 'offset']], "," + ) + finally: + Select._hash_recursion.pop(self) + def clear_from(self, id): self.append_from(FromClause(from_name = None, from_key = id)) @@ -1256,6 +1283,9 @@ class Select(SelectBaseMixin, FromClause): class UpdateBase(ClauseElement): """forms the base for INSERT, UPDATE, and DELETE statements.""" + def hash_key(self): + return str(id(self)) + def _process_colparams(self, parameters): """receives the "values" of an INSERT or UPDATE statement and constructs appropriate ind parameters.""" diff --git a/lib/sqlalchemy/util.py b/lib/sqlalchemy/util.py index 7f417ba7e6..8592dd98d2 100644 --- a/lib/sqlalchemy/util.py +++ b/lib/sqlalchemy/util.py @@ -18,6 +18,16 @@ def to_list(x): def generic_repr(obj, exclude=None): L = ['%s=%s' % (a, repr(getattr(obj, a))) for a in dir(obj) if not callable(getattr(obj, a)) and not a.startswith('_') and (exclude is None or not exclude.has_key(a))] return '%s(%s)' % (obj.__class__.__name__, ','.join(L)) + +def hash_key(obj): + if obj is None: + return 'None' + elif isinstance(obj, list): + return repr([hash_key(o) for o in obj]) + elif hasattr(obj, 'hash_key'): + return obj.hash_key() + else: + return repr(obj) class OrderedProperties(object): """an object that maintains the order in which attributes are set upon it. @@ -49,7 +59,30 @@ class OrderedProperties(object): self.__dict__[key] = object - +class RecursionStack(object): + """a thread-local stack used to detect recursive object traversals.""" + def __init__(self): + self.stacks = {} + def _get_stack(self): + try: + stack = self.stacks[thread.get_ident()] + except KeyError: + stack = {} + self.stacks[thread.get_ident()] = stack + return stack + def push(self, obj): + s = self._get_stack() + if s.has_key(obj): + return True + else: + s[obj] = True + return False + def pop(self, obj): + stack = self._get_stack() + del stack[obj] + if len(stack) == 0: + del self.stacks[thread.get_ident()] + class OrderedDict(dict): """A Dictionary that keeps its own internal ordering""" def __init__(self, values = None): @@ -110,6 +143,13 @@ class ThreadLocal(object): def __init__(self, raiseerror = True): self.__dict__['_tdict'] = {} self.__dict__['_raiseerror'] = raiseerror + def __hasattr__(self, key): + return self._tdict.has_key("%d_%s" % (thread.get_ident(), key)) + def __delattr__(self, key): + try: + del self._tdict["%d_%s" % (thread.get_ident(), key)] + except KeyError: + raise AttributeError(key) def __getattr__(self, key): try: return self._tdict["%d_%s" % (thread.get_ident(), key)] @@ -121,6 +161,7 @@ class ThreadLocal(object): def __setattr__(self, key, value): self._tdict["%d_%s" % (thread.get_ident(), key)] = value + class HashSet(object): """implements a Set.""" def __init__(self, iter = None, ordered = False): diff --git a/test/select.py b/test/select.py index 1fa2fd456b..ca0eb0eac1 100644 --- a/test/select.py +++ b/test/select.py @@ -55,7 +55,8 @@ addresses = Table('addresses', db, class SQLTest(PersistTest): def runtest(self, clause, result, engine = None, params = None, checkparams = None): c = clause.compile(engine, params) - self.echo("\n" + str(c) + repr(c.get_params())) + self.echo("\nSQL String:\n" + str(c) + repr(c.get_params())) + self.echo("\nHash Key:\n" + clause.hash_key()) cc = re.sub(r'\n', '', str(c)) self.assert_(cc == result, str(c) + "\n does not match \n" + result) if checkparams is not None: