From: Mike Bayer Date: Thu, 5 Jan 2006 05:44:10 +0000 (+0000) Subject: added compare function to the more basic expression objects X-Git-Tag: rel_0_1_0~167 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=d5536dd30a5872410973de8599f0b41f5f4b2894;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git added compare function to the more basic expression objects adding priamry_key/foreign_keys to selects, alias etc to increase their useability for relating them to tables improved _get_col_by_original to double-check the column it finds --- diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py index 486fa304ea..2a1b6829f8 100644 --- a/lib/sqlalchemy/sql.py +++ b/lib/sqlalchemy/sql.py @@ -303,6 +303,13 @@ class ClauseElement(object): data.setdefault(f.id, f) if asfrom: data[self.id] = self + def compare(self, other): + """compares this ClauseElement to the given ClauseElement. + + Subclasses should override the default behavior, which is a straight + identity comparison.""" + return self is other + def accept_visitor(self, visitor): """accepts a ClauseVisitor and calls the appropriate visit_xxx method.""" raise NotImplementedError(repr(self)) @@ -521,6 +528,12 @@ class BindParamClause(ClauseElement, CompareMixin): return "BindParam(%s, %s, %s)" % (repr(self.key), repr(self.value), repr(self.shortname)) def typeprocess(self, value, engine): return self.type.convert_bind_param(value, engine) + def compare(self, other): + """compares this BindParamClause to the given clause. + + Since compare() is meant to compare statement syntax, this method + returns True if the two BindParamClauses have just the same type.""" + return isinstance(other, BindParamClause) and other.type.__class__ == self.type.__class__ class TextClause(ClauseElement): """represents literal a SQL text fragment. public constructor is the @@ -580,6 +593,17 @@ class ClauseList(ClauseElement): for c in self.clauses: f += c._get_from_objects() return f + def compare(self, other): + """compares this ClauseList to the given ClauseList, including + a comparison of all the clause items.""" + if isinstance(other, ClauseList) and len(self.clauses) == len(other.clauses): + for i in range(0, len(self.clauses)): + if not self.clauses[i].compare(other.clauses[i]): + return False + else: + return True + else: + return False class CompoundClause(ClauseList): """represents a list of clauses joined by an operator, such as AND or OR. @@ -606,8 +630,20 @@ class CompoundClause(ClauseList): return f def hash_key(self): return string.join([c.hash_key() for c in self.clauses], self.operator or " ") - - + def compare(self, other): + """compares this CompoundClause to the given item. + + In addition to the regular comparison, has the special case that it + returns True if this CompoundClause has only one item, and that + item matches the given item.""" + if not isinstance(other, CompoundClause): + if len(self.clauses) == 1: + return self.clauses[0].compare(other) + if ClauseList.compare(self, other): + return self.operator == other.operator + else: + return False + class Function(ClauseList, CompareMixin): """describes a SQL function. extends ClauseList to provide comparison operators.""" def __init__(self, name, *clauses, **kwargs): @@ -641,7 +677,6 @@ class Function(ClauseList, CompareMixin): return BinaryClause(self, obj, operator) def _make_proxy(self, selectable, name = None): return self - class BinaryClause(ClauseElement, CompareMixin): """represents two clauses with an operator in between""" @@ -665,7 +700,13 @@ class BinaryClause(ClauseElement, CompareMixin): c = self.left self.left = self.right self.right = c - + def compare(self, other): + """compares this BinaryClause against the given BinaryClause.""" + return ( + isinstance(other, BinaryClause) and self.operator == other.operator and + self.left.compare(other.left) and self.right.compare(other.right) + ) + class Join(FromClause): # TODO: put "using" + "natural" concepts in here and make "onclause" optional def __init__(self, left, right, onclause=None, isouter = False, allcols = True): @@ -713,7 +754,7 @@ class Join(FromClause): def _get_col_by_original(self, column): for c in self.columns: - if c.original is column: + if c.original is column.original: return c else: return None @@ -753,7 +794,6 @@ class Alias(FromClause): def __init__(self, selectable, alias = None): self.selectable = selectable self._columns = util.OrderedProperties() - self.foreign_keys = [] if alias is None: n = getattr(selectable, 'name') if n is None: @@ -770,12 +810,17 @@ class Alias(FromClause): co._make_proxy(self) primary_key = property (lambda self: [c for c in self.columns if c.primary_key]) - + foreign_keys = property(lambda s:s.selectable.foreign_keys) + def hash_key(self): return "Alias(%s, %s)" % (repr(self.selectable.hash_key()), repr(self.name)) def _get_col_by_original(self, column): - return self.columns.get(column.key, None) + c = self.columns.get(column.key, None) + if c is not None and c.original is column.original: + return c + else: + return None def accept_visitor(self, visitor): self.selectable.accept_visitor(visitor) @@ -874,8 +919,12 @@ class ColumnImpl(ColumnElement): def copy_container(self): return self.column + def compare(self, other): + """compares this ColumnImpl's column to the other given Column""" + return self.column is other + def _get_col_by_original(self, column): - if self.column.original is column: + if self.column.original is column.original: return self.column else: return None @@ -932,7 +981,7 @@ class TableImpl(FromClause): col = self.columns[column.key] except KeyError: return None - if col.original is column: + if col.original is column.original: return col else: return None @@ -1016,6 +1065,8 @@ class CompoundSelect(SelectBaseMixin, FromClause): if group_by: self.group_by(*group_by) + primary_key = property(lambda s:s.selects[0].primary_key) + foreign_keys = property(lambda s:s.selects[0].foreign_keys) columns = property(lambda s:s.selects[0].columns) def accept_visitor(self, visitor): for tup in self.clauses: @@ -1119,9 +1170,13 @@ class Select(SelectBaseMixin, FromClause): def _get_col_by_original(self, column): if self.use_labels: - return self.columns.get(column._label,None) + c = self.columns.get(column._label,None) else: - return self.columns.get(column.key,None) + c = self.columns.get(column.key,None) + if c is not None and c.original is column.original: + return c + else: + return None def _pks(self): ret = {} @@ -1130,8 +1185,15 @@ class Select(SelectBaseMixin, FromClause): if c.primary_key: ret[c] = c return ret.keys() - primary_key = property (lambda self: self._pks()) - + def _fks(self): + ret = [] + for from_obj in self._get_froms(): + for fk in from_obj.foreign_keys: + ret.append(fk) + return ret + primary_key = property (_pks) + foreign_keys = property(_fks) + def append_whereclause(self, whereclause): self._append_condition('whereclause', whereclause) def append_having(self, having):