From: Mike Bayer Date: Fri, 2 Dec 2005 08:49:45 +0000 (+0000) Subject: added functionality to map columns to their aliased versions. X-Git-Tag: rel_0_1_0~279 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=9a205e891d9b63ee83590a20772f3ee22e713398;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git added functionality to map columns to their aliased versions. added support for specifying an alias in a relation. added a new relation flag 'selectalias' which causes eagerloader to use a local alias name for its target table, translating columns back to the original non-aliased name as result rows come in. --- diff --git a/lib/sqlalchemy/mapper.py b/lib/sqlalchemy/mapper.py index 86527a1105..12b3d8e61a 100644 --- a/lib/sqlalchemy/mapper.py +++ b/lib/sqlalchemy/mapper.py @@ -51,10 +51,10 @@ def _relation_loader(mapper, secondary=None, primaryjoin=None, secondaryjoin=Non def _relation_mapper(class_, table=None, secondary=None, primaryjoin=None, secondaryjoin=None, - foreignkey=None, uselist=None, private=False, live=False, association=None, lazy=True, **kwargs): + foreignkey=None, uselist=None, private=False, live=False, association=None, lazy=True, selectalias=None, **kwargs): return _relation_loader(mapper(class_, table, **kwargs), secondary, primaryjoin, secondaryjoin, - foreignkey=foreignkey, uselist=uselist, private=private, live=live, association=association, lazy=lazy) + foreignkey=foreignkey, uselist=uselist, private=private, live=live, association=association, lazy=lazy, selectalias=selectalias) #def _relation_mapper(class_, table=None, secondary=None, # primaryjoin=None, secondaryjoin=None, foreignkey=None, @@ -444,7 +444,7 @@ class Mapper(object): def _getattrbycolumn(self, obj, column): try: - prop = self.columntoproperty[column] + prop = self.columntoproperty[column.original] except KeyError: try: prop = self.props[column.key] @@ -455,7 +455,7 @@ class Mapper(object): return prop[0].getattr(obj) def _setattrbycolumn(self, obj, column, value): - self.columntoproperty[column][0].setattr(obj, value) + self.columntoproperty[column.original][0].setattr(obj, value) def save_obj(self, objects, uow): @@ -700,7 +700,7 @@ class PropertyLoader(MapperProperty): """describes an object property that holds a single item or list of items that correspond to a related database table.""" - def __init__(self, argument, secondary, primaryjoin, secondaryjoin, foreignkey=None, uselist=None, private=False, live=False, isoption=False, association=None, **kwargs): + def __init__(self, argument, secondary, primaryjoin, secondaryjoin, foreignkey=None, uselist=None, private=False, live=False, isoption=False, association=None, selectalias=None, **kwargs): self.uselist = uselist self.argument = argument self.secondary = secondary @@ -711,6 +711,7 @@ class PropertyLoader(MapperProperty): self.live = live self.isoption = isoption self.association = association + self.selectalias = selectalias self._hash_key = "%s(%s, %s, %s, %s, %s, %s, %s)" % (self.__class__.__name__, hash_key(self.argument), hash_key(secondary), hash_key(primaryjoin), hash_key(secondaryjoin), hash_key(foreignkey), repr(uselist), repr(private)) def _copy(self): @@ -728,7 +729,7 @@ class PropertyLoader(MapperProperty): if self.association is not None: if isinstance(self.association, type): self.association = class_mapper(self.association) - + self.target = self.mapper.table self.key = key self.parent = parent @@ -812,16 +813,15 @@ class PropertyLoader(MapperProperty): def _match_primaries(self, primary, secondary): crit = [] for fk in secondary.foreign_keys: - if fk.column.table is primary: - crit.append(fk.column == fk.parent) + if fk.references(primary): + crit.append(primary.get_col_by_original(fk.column) == fk.parent) self.foreignkey = fk.parent for fk in primary.foreign_keys: - if fk.column.table is secondary: - crit.append(fk.column == fk.parent) + if fk.references(secondary): + crit.append(secondary.get_col_by_original(fk.column) == fk.parent) self.foreignkey = fk.parent - if len(crit) == 0: - raise "Cant find any foreign key relationships between '%s' (%s) and '%s' (%s)" % (primary.table.name, repr(primary.table), secondary.table.name, repr(secondary.table)) + raise "Cant find any foreign key relationships between '%s' (%s) and '%s' (%s)" % (primary.name, repr(primary), secondary.name, repr(secondary)) elif len(crit) == 1: return (crit[0]) else: @@ -1154,6 +1154,26 @@ class EagerLoader(PropertyLoader): if self.secondaryjoin is not None: [self.to_alias.append(f) for f in self.secondaryjoin._get_from_objects()] del self.to_alias[parent.primarytable] + + # if this eagermapper is to select using an "alias" to isolate it from other + # eager mappers against the same table, we have to redefine our secondary + # or primary join condition to reference the aliased table. else + # we set up the target clause objects as what they are defined in the + # superclass. + if self.selectalias is not None: + self.eagertarget = self.target.alias(self.selectalias) + aliasizer = Aliasizer(self.target, aliases={self.target:self.eagertarget}) + if self.secondaryjoin is not None: + self.eagersecondary = self.secondaryjoin.copy_container() + self.eagersecondary.accept_visitor(aliasizer) + self.eagerpriamry = self.primaryjoin + else: + self.eagerprimary = self.primaryjoin.copy_container() + self.eagerprimary.accept_visitor(aliasizer) + else: + self.eagertarget = self.target + self.eagerprimary = self.primaryjoin + self.eagersecondary = self.secondaryjoin def setup(self, key, statement, **options): """add a left outer join to the statement thats being constructed""" @@ -1173,14 +1193,14 @@ class EagerLoader(PropertyLoader): towrap = self.parent.table if self.secondaryjoin is not None: - statement._outerjoin = sql.outerjoin(towrap, self.secondary, self.primaryjoin).outerjoin(self.target, self.secondaryjoin) + statement._outerjoin = sql.outerjoin(towrap, self.secondary, self.primaryjoin).outerjoin(self.eagertarget, self.eagersecondary) statement.order_by(self.secondary.rowid_column) else: - statement._outerjoin = towrap.outerjoin(self.target, self.primaryjoin) + statement._outerjoin = towrap.outerjoin(self.eagertarget, self.eagerprimary) statement.order_by(self.target.rowid_column) statement.append_from(statement._outerjoin) - statement.append_column(self.target) + statement.append_column(self.eagertarget) for key, value in self.mapper.props.iteritems(): if value is self: raise "Cant use eager loading on a self-referential mapper relationship " + str(self.mapper) + " " + key + repr(self.mapper.props) @@ -1197,14 +1217,27 @@ class EagerLoader(PropertyLoader): if not self.uselist: if isnew: - h.setattr(self.mapper._instance(row, imap)) + h.setattr(self._instance(row, imap)) return elif isnew: result_list = h else: result_list = getattr(instance, self.key) - - self.mapper._instance(row, imap, result_list) + + self._instance(row, imap, result_list) + + def _instance(self, row, imap, result_list=None): + """gets an instance from a row, via this EagerLoader's mapper.""" + # if we have an alias for our mapper's table via the selectalias + # parameter, we need to translate the + # aliased columns from the incoming row into a new row that maps + # the values against the columns of the mapper's original non-aliased table. + if self.selectalias is not None: + fakerow = {} + for c in self.eagertarget.c: + fakerow[c.original] = row[c] + row = fakerow + return self.mapper._instance(row, imap, result_list) class MapperOption(object): """describes a modification to a Mapper in the context of making a copy @@ -1253,13 +1286,13 @@ class EagerLazyOption(MapperOption): class Aliasizer(sql.ClauseVisitor): """converts a table instance within an expression to be an alias of that table.""" - def __init__(self, *tables): + def __init__(self, *tables, **kwargs): self.tables = {} for t in tables: self.tables[t] = t self.binary = None self.match = False - self.aliases = {} + self.aliases = kwargs.get('aliases', {}) def get_alias(self, table): try: diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py index 3eb68d8b91..527d57023c 100644 --- a/lib/sqlalchemy/schema.py +++ b/lib/sqlalchemy/schema.py @@ -196,7 +196,11 @@ class Column(SchemaItem): def _make_proxy(self, selectable, name = None): """creates a copy of this Column, initialized the way this Column is""" - c = Column(name or self.name, self.type, self.foreign_key, self.sequence, key = name or self.key, primary_key = self.primary_key, hidden=self.hidden) + if self.foreign_key is None: + fk = None + else: + fk = self.foreign_key.copy() + c = Column(name or self.name, self.type, fk, self.sequence, key = name or self.key, primary_key = self.primary_key, hidden=self.hidden) c.table = selectable c._orig = self.original if not c.hidden: @@ -229,6 +233,16 @@ class ForeignKey(SchemaItem): return ForeignKey(self._colspec) else: return ForeignKey("%s.%s" % (self._colspec.table.name, self._colspec.column.key)) + + def references(self, table): + """returns True if the given table is referenced by this ForeignKey.""" + return ( + # simple test + self.column.table is table + or + # test for an indirect relation via a Selectable + table.get_col_by_original(self.column) is not None + ) def _init_column(self): # ForeignKey inits its remote column as late as possible, so tables can diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py index fc4d427626..937bc90478 100644 --- a/lib/sqlalchemy/sql.py +++ b/lib/sqlalchemy/sql.py @@ -585,6 +585,13 @@ class Selectable(FromClause): def select(self, whereclauses = None, **params): return select([self], whereclauses, **params) + def get_col_by_original(self, column): + """given a column which is a schema.Column object attached to a schema.Table object + (i.e. an "original" column), return the Column object from this + Selectable which corresponds to that original Column, or None if this Selectable + does not contain the column.""" + raise NotImplementedError() + def join(self, right, *args, **kwargs): return Join(self, right, *args, **kwargs) @@ -626,6 +633,13 @@ class Join(Selectable): statement""" return True + def get_col_by_original(self, column): + for c in self.columns: + if c.original is column: + return c + else: + return None + def hash_key(self): return "Join(%s, %s, %s, %s)" % (repr(self.left.hash_key()), repr(self.right.hash_key()), repr(self.onclause.hash_key()), repr(self.isouter)) @@ -661,6 +675,7 @@ class Alias(Selectable): def __init__(self, selectable, alias = None): self.selectable = selectable self.columns = util.OrderedProperties() + self.foreign_keys = [] if alias is None: alias = id(self) self.name = alias @@ -671,10 +686,13 @@ class Alias(Selectable): co._make_proxy(self) primary_keys = property (lambda self: [c for c in self.columns if c.primary_key]) - + def hash_key(self): return "Alias(%s, %s)" % (repr(self.selectable.hash_key()), repr(self.name)) + def get_col_by_original(self, column): + return self.columns.get(column.key, None) + def accept_visitor(self, visitor): self.selectable.accept_visitor(visitor) visitor.visit_alias(self) @@ -708,6 +726,12 @@ class ColumnImpl(Selectable, CompareMixin): def copy_container(self): return self.column + def get_col_by_original(self, column): + if self.column.original is column: + return self.column + else: + return None + def group_parenthesized(self): return False @@ -746,7 +770,10 @@ class TableImpl(Selectable): return self.table.name engine = property(lambda s: s.table.engine) - + + def get_col_by_original(self, column): + return self.columns.get(column.key, None) + def group_parenthesized(self): return False @@ -865,6 +892,13 @@ class Select(Selectable): else: co._make_proxy(self) + + def get_col_by_original(self, column): + if self.use_labels: + return self.columns.get(column.label,None) + else: + return self.columns.get(column.key,None) + def append_whereclause(self, whereclause): self._append_condition('whereclause', whereclause) def append_having(self, having): @@ -904,7 +938,7 @@ class Select(Selectable): def _get_froms(self): return [f for f in self._froms.values() if self._correlated is None or not self._correlated.has_key(f.id)] froms = property(lambda s: s._get_froms()) - + def accept_visitor(self, visitor): for f in self.froms: f.accept_visitor(visitor) diff --git a/lib/sqlalchemy/util.py b/lib/sqlalchemy/util.py index c474e3d101..32a8efa41a 100644 --- a/lib/sqlalchemy/util.py +++ b/lib/sqlalchemy/util.py @@ -27,6 +27,10 @@ class OrderedProperties(object): self.__dict__['_list'] = [] def keys(self): return self._list + def get(self, key, default): + return getattr(self, key, default) + def has_key(self, key): + return hasattr(self, key) def __iter__(self): return iter([self[x] for x in self._list]) def __setitem__(self, key, object):