From 113398519a74e58c09cda2c130a34f7ae4b0c417 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Sun, 8 Jan 2006 01:26:47 +0000 Subject: [PATCH] improvements to relational algrebra of Alias, Select, Join objects, so that they all report their column lists, primary key, foreign key lists consistently and so that ForeignKey objects can line up tables against relational objects --- lib/sqlalchemy/schema.py | 24 +++-- lib/sqlalchemy/sql.py | 205 ++++++++++++++++++--------------------- 2 files changed, 110 insertions(+), 119 deletions(-) diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py index c5054cd500..1d5b504e2e 100644 --- a/lib/sqlalchemy/schema.py +++ b/lib/sqlalchemy/schema.py @@ -281,7 +281,11 @@ class Column(SchemaItem): c._orig = self.original if not c.hidden: selectable.columns[c.key] = c + if self.primary_key: + selectable.primary_key.append(c) c._impl = self.engine.columnimpl(c) + if fk is not None: + c._init_items(fk) return c def accept_visitor(self, visitor): @@ -325,17 +329,17 @@ class ForeignKey(SchemaItem): if isinstance(self._colspec, str): return ForeignKey(self._colspec) else: - return ForeignKey("%s.%s" % (self._colspec.table.name, self._colspec.column.key)) + if self._colspec.table.schema is not None: + return ForeignKey("%s.%s.%s" % (self._colspec.table.schema, self._colspec.table.name, self._colspec.column.key)) + else: + return ForeignKey("%s.%s" % (self._colspec.table.name, self._colspec.column.key)) def references(self, table): """returns True if the given table is referenced by this ForeignKey.""" - return ( - # simple test - self.column.table is table - or - # test for an indirect relation via a Selectable - table._get_col_by_original(self.column) is not None - ) + try: + return table._get_col_by_original(self.column) is not None + except: + x = self._init_column() def _init_column(self): # ForeignKey inits its remote column as late as possible, so tables can @@ -347,8 +351,8 @@ class ForeignKey(SchemaItem): raise ValueError("Invalid foreign key column specification: " + self._colspec) if m.group(3) is None: (tname, colname) = m.group(1, 2) - # default to containing table's schema - schema = self.parent.table.schema + # use default schema + schema = None else: (schema,tname,colname) = m.group(1,2,3) table = Table(tname, self.parent.engine, mustexist=True, schema=schema) diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py index 2a1b6829f8..e08c644f47 100644 --- a/lib/sqlalchemy/sql.py +++ b/lib/sqlalchemy/sql.py @@ -340,7 +340,6 @@ class ClauseElement(object): return None engine = property(lambda s: s._find_engine()) - def compile(self, engine = None, parameters = None, typemap=None): """compiles this SQL expression using its underlying SQLEngine to produce @@ -449,16 +448,8 @@ class CompareMixin(object): class Selectable(ClauseElement): """represents a column list-holding object.""" - def _get_columns(self): - try: - return self._columns - except AttributeError: - return [self] - columns = property(lambda s: s._get_columns()) - c = property(lambda self: self.columns) - def accept_visitor(self, visitor): - raise NotImplementedError() + raise NotImplementedError(repr(self)) def is_selectable(self): return True @@ -466,22 +457,17 @@ class Selectable(ClauseElement): def select(self, whereclauses = None, **params): return select([self], whereclauses, **params) - def _get_col_by_original(self, column): - """given a column which is a schema.Column object attached to a schema.Table object - (i.e. an "original" column), return the Column object from this - Selectable which corresponds to that original Column, or None if this Selectable - does not contain the column.""" - raise NotImplementedError() - def _group_parenthesized(self): """indicates if this Selectable requires parenthesis when grouped into a compound statement""" return True + class ColumnElement(Selectable, CompareMixin): """represents a column element within the list of a Selectable's columns.""" - primary_key = property(lambda s:False) - original = property(lambda self:self) + primary_key = property(lambda s:getattr(self, '_primary_key', False)) + original = property(lambda self:getattr(self, '_original', self)) + columns = property(lambda self:[self]) class FromClause(Selectable): """represents an element that can be used within the FROM clause of a SELECT statement.""" @@ -508,7 +494,42 @@ class FromClause(Selectable): return Join(self, right, isouter = True, *args, **kwargs) def alias(self, name=None): return Alias(self, name) - + def _get_col_by_original(self, column): + """given a column which is a schema.Column object attached to a schema.Table object + (i.e. an "original" column), return the Column object from this + Selectable which corresponds to that original Column, or None if this Selectable + does not contain the column.""" + return self.original_columns[column.original] + def _get_exported_attribute(self, name): + try: + return getattr(self, name) + except AttributeError: + self._export_columns() + return getattr(self, name) + columns = property(lambda s:s._get_exported_attribute('_columns')) + c = property(lambda s:s._get_exported_attribute('_columns')) + primary_key = property(lambda s:s._get_exported_attribute('_primary_key')) + foreign_keys = property(lambda s:s._get_exported_attribute('_foreign_keys')) + original_columns = property(lambda s:s._get_exported_attribute('_orig_cols')) + + def _export_columns(self): + if hasattr(self, '_columns'): + # TODO: put a mutex here ? this is a key place for threading probs + return + self._columns = util.OrderedProperties() + self._primary_key = [] + self._foreign_keys = [] + self._orig_cols = {} + export = self._exportable_columns() + for column in export: + if column.is_selectable(): + for co in column.columns: + cp = self._proxy_column(co) + self._orig_cols[co.original] = cp + def _exportable_columns(self): + raise NotImplementedError(repr(self)) + def _proxy_column(self, column): + return column._make_proxy(self) class BindParamClause(ClauseElement, CompareMixin): """represents a bind parameter. public constructor is the bindparam() function.""" @@ -708,16 +729,10 @@ class BinaryClause(ClauseElement, CompareMixin): ) class Join(FromClause): - # TODO: put "using" + "natural" concepts in here and make "onclause" optional - def __init__(self, left, right, onclause=None, isouter = False, allcols = True): + def __init__(self, left, right, onclause=None, isouter = False): self.left = left self.right = right self.id = self.left.id + "_" + self.right.id - self.allcols = allcols - if allcols: - self._columns = [c for c in self.left.columns] + [c for c in self.right.columns] - else: - self._columns = self.right.columns # TODO: if no onclause, do NATURAL JOIN if onclause is None: @@ -726,9 +741,15 @@ class Join(FromClause): self.onclause = onclause self.isouter = isouter self.oid_column = self.left.oid_column - - primary_key = property (lambda self: [c for c in self.left.columns if c.primary_key] + [c for c in self.right.columns if c.primary_key]) - + def _exportable_columns(self): + return [c for c in self.left.columns] + [c for c in self.right.columns] + def _proxy_column(self, column): + self._columns[column.table.name + "_" + column.key] = column + if column.primary_key: + self._primary_key.append(column) + if column.foreign_key: + self._foreign_keys.append(column) + return column def _match_primaries(self, primary, secondary): crit = [] for fk in secondary.foreign_keys: @@ -752,13 +773,6 @@ class Join(FromClause): statement""" return True - def _get_col_by_original(self, column): - for c in self.columns: - if c.original is column.original: - return c - else: - return None - def hash_key(self): return "Join(%s, %s, %s, %s)" % (repr(self.left.hash_key()), repr(self.right.hash_key()), repr(self.onclause.hash_key()), repr(self.isouter)) @@ -792,8 +806,9 @@ class Join(FromClause): class Alias(FromClause): def __init__(self, selectable, alias = None): + while isinstance(selectable, Alias): + selectable = selectable.selectable self.selectable = selectable - self._columns = util.OrderedProperties() if alias is None: n = getattr(selectable, 'name') if n is None: @@ -806,22 +821,13 @@ class Alias(FromClause): self.oid_column = self.selectable.oid_column._make_proxy(self) else: self.oid_column = None - for co in selectable.columns: - co._make_proxy(self) - primary_key = property (lambda self: [c for c in self.columns if c.primary_key]) - foreign_keys = property(lambda s:s.selectable.foreign_keys) + def _exportable_columns(self): + return self.selectable.columns def hash_key(self): return "Alias(%s, %s)" % (repr(self.selectable.hash_key()), repr(self.name)) - def _get_col_by_original(self, column): - c = self.columns.get(column.key, None) - if c is not None and c.original is column.original: - return c - else: - return None - def accept_visitor(self, visitor): self.selectable.accept_visitor(visitor) visitor.visit_alias(self) @@ -850,7 +856,9 @@ class Label(ColumnElement): def _get_from_objects(self): return self.obj._get_from_objects() def _make_proxy(self, selectable, name = None): - cc = ColumnClause(self.name) + # TODO: this make_proxy needs foreign_key, primary_key support + # from its underlying column, if any + cc = ColumnClause(self.name, selectable) selectable.c[self.name] = cc return cc @@ -903,7 +911,6 @@ class ColumnImpl(ColumnElement): def __init__(self, column): self.column = column self.name = column.name - self._columns = [self.column] if column.table.name: self._label = column.table.name + "_" + self.column.name @@ -912,6 +919,8 @@ class ColumnImpl(ColumnElement): engine = property(lambda s: s.column.engine) default_label = property(lambda s:s._label) + original = property(lambda self:self.column) + columns = property(lambda self:[self.column]) def label(self, name): return Label(name, self.column) @@ -923,12 +932,6 @@ class ColumnImpl(ColumnElement): """compares this ColumnImpl's column to the other given Column""" return self.column is other - def _get_col_by_original(self, column): - if self.column.original is column.original: - return self.column - else: - return None - def _group_parenthesized(self): return False @@ -940,7 +943,7 @@ class ColumnImpl(ColumnElement): return BindParamClause(self.name, obj, shortname = self.name, type = self.column.type) else: return BindParamClause(self.column.table.name + "_" + self.name, obj, shortname = self.name, type = self.column.type) - + def _compare(self, operator, obj): if _is_literal(obj): if obj is None: @@ -952,6 +955,13 @@ class ColumnImpl(ColumnElement): return BinaryClause(self.column, obj, operator, type=self.column.type) + def compile(self, engine = None, parameters = None, typemap=None): + if engine is None: + engine = self.engine + if engine is None: + raise "no SQLEngine could be located within this ClauseElement." + return engine.compile(self.column, parameters=parameters, typemap=typemap) + class TableImpl(FromClause): """attached to a schema.Table to provide it with a Selectable interface as well as other functions @@ -971,21 +981,25 @@ class TableImpl(FromClause): self._oid_column = None return self._oid_column + def _orig_columns(self): + try: + return self._orig_cols + except AttributeError: + self._orig_cols= {} + for c in self.columns: + self._orig_cols[c.original] = c + return self._orig_cols + oid_column = property(_oid_col) engine = property(lambda s: s.table.engine) columns = property(lambda self: self.table.columns) primary_key = property(lambda self:self.table.primary_key) - - def _get_col_by_original(self, column): - try: - col = self.columns[column.key] - except KeyError: - return None - if col.original is column.original: - return col - else: - return None + foreign_keys = property(lambda self:self.table.foreign_keys) + original_columns = property(_orig_columns) + def _exportable_columns(self): + raise NotImplementedError() + def _group_parenthesized(self): return False @@ -1065,9 +1079,14 @@ class CompoundSelect(SelectBaseMixin, FromClause): if group_by: self.group_by(*group_by) - primary_key = property(lambda s:s.selects[0].primary_key) - foreign_keys = property(lambda s:s.selects[0].foreign_keys) - columns = property(lambda s:s.selects[0].columns) + def _exportable_columns(self): + return self.selects[0].columns + def _proxy_column(self, column): + self._columns[column.key] = column + if column.primary_key: + self._primary_key.append(column) + if column.foreign_key: + self._foreign_keys.append(column) def accept_visitor(self, visitor): for tup in self.clauses: tup[1].accept_visitor(visitor) @@ -1086,7 +1105,6 @@ class Select(SelectBaseMixin, FromClause): """represents a SELECT statement, with appendable clauses, as well as the ability to execute itself and return a result set.""" def __init__(self, columns=None, whereclause = None, from_obj = [], order_by = None, group_by=None, having=None, use_labels = False, distinct=False, engine = None, limit=None, offset=None): - self._columns = util.OrderedProperties() self._froms = util.OrderedDict() self.use_labels = use_labels self.id = "Select(%d)" % id(self) @@ -1155,45 +1173,14 @@ class Select(SelectBaseMixin, FromClause): f.accept_visitor(self._correlator) column._process_from_dict(self._froms, False) - if column.is_selectable(): - # if its a column unit, add it to our exported - # list of columns. this is where "columns" - # attribute of the select object gets populated. - # notice we are overriding the names of the column - # with either its label or its key, since one or the other - # is used when selecting from a select statement (i.e. a subquery) - for co in column.columns: - if self.use_labels: - co._make_proxy(self, name=co._label) - else: - co._make_proxy(self, name=co.key) - - def _get_col_by_original(self, column): + def _exportable_columns(self): + return self._raw_columns + def _proxy_column(self, column): if self.use_labels: - c = self.columns.get(column._label,None) - else: - c = self.columns.get(column.key,None) - if c is not None and c.original is column.original: - return c + return column._make_proxy(self, name=column._label) else: - return None - - def _pks(self): - ret = {} - for from_obj in self._get_froms(): - for c in from_obj.c: - if c.primary_key: - ret[c] = c - return ret.keys() - def _fks(self): - ret = [] - for from_obj in self._get_froms(): - for fk in from_obj.foreign_keys: - ret.append(fk) - return ret - primary_key = property (_pks) - foreign_keys = property(_fks) - + return column._make_proxy(self, name=column.key) + def append_whereclause(self, whereclause): self._append_condition('whereclause', whereclause) def append_having(self, having): -- 2.47.2