From 4f8f7ecd9ec2b5f6f57712548241acb3bf71f12c Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Fri, 27 Jan 2006 00:33:15 +0000 Subject: [PATCH] refactoring to allow column.label() to work in selects, etc. fixed superfluous codeline in ForeignKey --- lib/sqlalchemy/schema.py | 5 +- lib/sqlalchemy/sql.py | 137 +++++++++++++++++++++------------------ test/selectable.py | 4 ++ 3 files changed, 80 insertions(+), 66 deletions(-) diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py index c55d9034a0..6679c8b057 100644 --- a/lib/sqlalchemy/schema.py +++ b/lib/sqlalchemy/schema.py @@ -367,10 +367,7 @@ class ForeignKey(SchemaItem): def references(self, table): """returns True if the given table is referenced by this ForeignKey.""" - try: - return table._get_col_by_original(self.column) is not None - except: - x = self._init_column() + return table._get_col_by_original(self.column) is not None def _init_column(self): # ForeignKey inits its remote column as late as possible, so tables can diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py index c37a8d620f..c3048de294 100644 --- a/lib/sqlalchemy/sql.py +++ b/lib/sqlalchemy/sql.py @@ -121,12 +121,12 @@ def or_(*clauses): def not_(clause): """returns a negation of the given clause, i.e. NOT(clause). the ~ operator can be used as well.""" clause.parens=True - return BinaryClause(TextClause("NOT"), clause, None) + return BooleanExpression(TextClause("NOT"), clause, None) def exists(*args, **params): s = select(*args, **params) - return BinaryClause(TextClause("EXISTS"), s, None) + return BooleanExpression(TextClause("EXISTS"), s, None) def union(*selects, **params): return _compound_select('UNION', *selects, **params) @@ -441,15 +441,15 @@ class CompareMixin(object): return lambda other: self._compare(operator, other) # and here come the math operators: def __add__(self, other): - return self._compare('+', other) + return self._operate('+', other) def __sub__(self, other): - return self._compare('-', other) + return self._operate('-', other) def __mul__(self, other): - return self._compare('*', other) + return self._operate('*', other) def __div__(self, other): - return self._compare('/', other) + return self._operate('/', other) def __truediv__(self, other): - return self._compare('/', other) + return self._operate('/', other) def _bind_param(self, obj): return BindParamClause('literal', obj, shortname=None, type=self.type) def _compare(self, operator, obj): @@ -457,12 +457,24 @@ class CompareMixin(object): if obj is None: if operator != '=': raise "Only '=' operator can be used with NULL" - return BinaryClause(self, null(), 'IS') + return BooleanExpression(self, null(), 'IS') else: obj = self._bind_param(obj) - return BinaryClause(self, obj, operator, type=obj.type) - + return BooleanExpression(self._compare_self(), obj, operator, type=self._compare_type(obj)) + def _operate(self, operator, obj): + if _is_literal(obj): + obj = self._bind_param(obj) + return BinaryExpression(self._compare_self(), obj, operator, type=self._compare_type(obj)) + def _compare_self(self): + """allows ColumnImpl to return its Column object for usage in ClauseElements, all others to + just return self""" + return self + def _compare_type(self, obj): + """allows subclasses to override the type used in constructing BinaryClause objects. Default return + value is the type of the given object.""" + return obj.type + class Selectable(ClauseElement): """represents a column list-holding object.""" @@ -482,11 +494,27 @@ class Selectable(ClauseElement): class ColumnElement(Selectable, CompareMixin): - """represents a column element within the list of a Selectable's columns.""" + """represents a column element within the list of a Selectable's columns. Provides + default implementations for the things a "column" needs, including a "primary_key" flag, + a "foreign_key" accessor, an "original" accessor which represents the ultimate column + underlying a string of labeled/select-wrapped columns, and "columns" which returns a list + of the single column, providing the same list-based interface as a FromClause.""" primary_key = property(lambda self:getattr(self, '_primary_key', False)) foreign_key = property(lambda self:getattr(self, '_foreign_key', False)) original = property(lambda self:getattr(self, '_original', self)) columns = property(lambda self:[self]) + def _make_proxy(self, selectable, name=None): + """creates a new ColumnElement representing this ColumnElement as it appears in the select list + of an enclosing selectable. The default implementation returns a ColumnClause if a name is given, + else just returns self. This has various mechanics with schema.Column and sql.Label so that + Column objects as well as non-column objects like Function and BinaryClause can both appear in the + select list of an enclosing selectable.""" + if name is not None: + co = ColumnClause(name, selectable) + selectable.columns[name]= co + return co + else: + return self class FromClause(Selectable): """represents an element that can be used within the FROM clause of a SELECT statement.""" @@ -520,7 +548,7 @@ class FromClause(Selectable): (i.e. an "original" column), return the Column object from this Selectable which corresponds to that original Column, or None if this Selectable does not contain the column.""" - return self.original_columns[column.original] + return self.original_columns.get(column.original, None) def _get_exported_attribute(self, name): try: return getattr(self, name) @@ -706,7 +734,7 @@ class CompoundClause(ClauseList): else: return False -class Function(ClauseList, CompareMixin): +class Function(ClauseList, ColumnElement): """describes a SQL function. extends ClauseList to provide comparison operators.""" def __init__(self, name, *clauses, **kwargs): self.name = name @@ -722,29 +750,19 @@ class Function(ClauseList, CompareMixin): self.clauses.append(clause) def copy_container(self): clauses = [clause.copy_container() for clause in self.clauses] - return Function(self.name, label=self.label, type=self.type, *clauses) + return Function(self.name, type=self.type, *clauses) def accept_visitor(self, visitor): for c in self.clauses: c.accept_visitor(visitor) visitor.visit_function(self) - def _compare(self, operator, obj): - if _is_literal(obj): - if obj is None: - if operator != '=': - raise "Only '=' operator can be used with NULL" - return BinaryClause(self, null(), 'IS') - else: - obj = BindParamClause(self.name, obj, shortname=self.name, type=self.type) - - return BinaryClause(self, obj, operator) - def _make_proxy(self, selectable, name = None): - return self + def _bind_param(self, obj): + return BindParamClause(self.name, obj, shortname=self.name, type=self.type) def select(self): return select([self]) def hash_key(self): return self.name + "(" + string.join([c.hash_key() for c in self.clauses], ", ") + ")" -class BinaryClause(ClauseElement, CompareMixin): +class BinaryClause(ClauseElement): """represents two clauses with an operator in between""" def __init__(self, left, right, operator, type=None): self.left = left @@ -772,6 +790,16 @@ class BinaryClause(ClauseElement, CompareMixin): isinstance(other, BinaryClause) and self.operator == other.operator and self.left.compare(other.left) and self.right.compare(other.right) ) + +class BooleanExpression(BinaryClause): + """represents a boolean expression, which is only useable in WHERE criterion.""" + pass +class BinaryExpression(BinaryClause, ColumnElement): + """represents a binary expression, which can be in a WHERE criterion or in the column list + of a SELECT. By adding "ColumnElement" to its inherited list, it becomes a Selectable + unit which can be placed in the column list of a SELECT.""" + pass + class Join(FromClause): def __init__(self, left, right, onclause=None, isouter = False): @@ -873,7 +901,7 @@ class Alias(FromClause): return self.selectable.columns def hash_key(self): - return "Alias(%s, %s)" % (repr(self.selectable.hash_key()), repr(self.name)) + return "Alias(%s, %s)" % (self.selectable.hash_key(), repr(self.name)) def accept_visitor(self, visitor): self.selectable.accept_visitor(visitor) @@ -897,17 +925,15 @@ class Label(ColumnElement): obj.parens=True key = property(lambda s: s.name) _label = property(lambda s: s.name) + original = property(lambda s:s.obj.original) def accept_visitor(self, visitor): self.obj.accept_visitor(visitor) visitor.visit_label(self) def _get_from_objects(self): return self.obj._get_from_objects() def _make_proxy(self, selectable, name = None): - # TODO: this make_proxy needs foreign_key, primary_key support - # from its underlying column, if any - cc = ColumnClause(self.name, selectable) - selectable.c[self.name] = cc - return cc + return self.obj._make_proxy(selectable, name=self.name) + def hash_key(self): return "Label(%s, %s)" % (self.name, self.obj.hash_key()) @@ -936,19 +962,11 @@ class ColumnClause(ColumnElement): def _get_from_objects(self): return [] - def _compare(self, operator, obj): - if _is_literal(obj): - if obj is None: - if operator != '=': - raise "Only '=' operator can be used with NULL" - return BinaryClause(self, null(), 'IS') - elif self.table.name is None: - obj = BindParamClause(self.text, obj, shortname=self.text, type=self.type) - else: - obj = BindParamClause(self.table.name + "_" + self.text, obj, shortname = self.text, type=self.type) - - return BinaryClause(self, obj, operator) - + def _bind_param(self, obj): + if self.table.name is None: + return BindParamClause(self.text, obj, shortname=self.text, type=self.type) + else: + return BindParamClause(self.table.name + "_" + self.text, obj, shortname = self.text, type=self.type) def _make_proxy(self, selectable, name = None): c = ColumnClause(self.text or name, selectable) selectable.columns[c.key] = c @@ -992,18 +1010,13 @@ class ColumnImpl(ColumnElement): return BindParamClause(self.name, obj, shortname = self.name, type = self.column.type) else: return BindParamClause(self.column.table.name + "_" + self.name, obj, shortname = self.name, type = self.column.type) - - def _compare(self, operator, obj): - if _is_literal(obj): - if obj is None: - if operator != '=': - raise "Only '=' operator can be used with NULL" - return BinaryClause(self.column, null(), 'IS') - else: - obj = self._bind_param(obj) - - return BinaryClause(self.column, obj, operator, type=self.column.type) - + def _compare_self(self): + """allows ColumnImpl to return its Column object for usage in ClauseElements, all others to + just return self""" + return self.column + def _compare_type(self, obj): + return self.column.type + def compile(self, engine = None, parameters = None, typemap=None): if engine is None: engine = self.engine @@ -1226,7 +1239,7 @@ class Select(SelectBaseMixin, FromClause): for f in column._get_from_objects(): f.accept_visitor(self._correlator) column._process_from_dict(self._froms, False) - + def _exportable_columns(self): return self._raw_columns def _proxy_column(self, column): @@ -1259,11 +1272,11 @@ class Select(SelectBaseMixin, FromClause): try: return "Select(%s)" % string.join( [ - "columns=" + repr([util.hash_key(c) for c in self._raw_columns]), + "columns=" + string.join([util.hash_key(c) for c in self._raw_columns],','), "where=" + util.hash_key(self.whereclause), - "from=" + repr([util.hash_key(f) for f in self.froms]), + "from=" + string.join([util.hash_key(f) for f in self.froms],','), "having=" + util.hash_key(self.having), - "clauses=" + repr([util.hash_key(c) for c in self.clauses]) + "clauses=" + string.join([util.hash_key(c) for c in self.clauses], ',') ] + ["%s=%s" % (k, repr(getattr(self, k))) for k in ['use_labels', 'distinct', 'limit', 'offset']], "," ) finally: diff --git a/test/selectable.py b/test/selectable.py index 37b1f28eca..aefef25e06 100755 --- a/test/selectable.py +++ b/test/selectable.py @@ -74,7 +74,11 @@ class SelectableTest(testbase.AssertMixin): a = select([table.c.col1.label('acol1'), table.c.col2.label('acol2'), table.c.col3.label('acol3')]) print str(a) print [c for c in a.columns] + print str(a.select()) j = join(a, table2) + criterion = a.c.acol1 == table2.c.col2 + print str(j) + self.assert_(criterion.compare(j.onclause)) def testselectaliaslabels(self): a = table2.select(use_labels=True).alias('a') -- 2.47.2