]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
added functionality to map columns to their aliased versions.
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 2 Dec 2005 08:49:45 +0000 (08:49 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 2 Dec 2005 08:49:45 +0000 (08:49 +0000)
added support for specifying an alias in a relation.
added a new relation flag 'selectalias' which causes eagerloader to use a local alias name for its target table, translating columns back to the original non-aliased name as result rows come in.

lib/sqlalchemy/mapper.py
lib/sqlalchemy/schema.py
lib/sqlalchemy/sql.py
lib/sqlalchemy/util.py

index 86527a1105bb82bdb721e6662f5096e834de1b54..12b3d8e61acce31f30f944230116ffcbb3e1e4ae 100644 (file)
@@ -51,10 +51,10 @@ def _relation_loader(mapper, secondary=None, primaryjoin=None, secondaryjoin=Non
 
 def _relation_mapper(class_, table=None, secondary=None, 
                     primaryjoin=None, secondaryjoin=None, 
-                    foreignkey=None, uselist=None, private=False, live=False, association=None, lazy=True, **kwargs):
+                    foreignkey=None, uselist=None, private=False, live=False, association=None, lazy=True, selectalias=None, **kwargs):
 
     return _relation_loader(mapper(class_, table, **kwargs), secondary, primaryjoin, secondaryjoin, 
-                    foreignkey=foreignkey, uselist=uselist, private=private, live=live, association=association, lazy=lazy)
+                    foreignkey=foreignkey, uselist=uselist, private=private, live=live, association=association, lazy=lazy, selectalias=selectalias)
 
 #def _relation_mapper(class_, table=None, secondary=None, 
 #                    primaryjoin=None, secondaryjoin=None, foreignkey=None, 
@@ -444,7 +444,7 @@ class Mapper(object):
 
     def _getattrbycolumn(self, obj, column):
         try:
-            prop = self.columntoproperty[column]
+            prop = self.columntoproperty[column.original]
         except KeyError:
             try:
                 prop = self.props[column.key]
@@ -455,7 +455,7 @@ class Mapper(object):
         return prop[0].getattr(obj)
 
     def _setattrbycolumn(self, obj, column, value):
-        self.columntoproperty[column][0].setattr(obj, value)
+        self.columntoproperty[column.original][0].setattr(obj, value)
 
         
     def save_obj(self, objects, uow):
@@ -700,7 +700,7 @@ class PropertyLoader(MapperProperty):
 
     """describes an object property that holds a single item or list of items that correspond
     to a related database table."""
-    def __init__(self, argument, secondary, primaryjoin, secondaryjoin, foreignkey=None, uselist=None, private=False, live=False, isoption=False, association=None, **kwargs):
+    def __init__(self, argument, secondary, primaryjoin, secondaryjoin, foreignkey=None, uselist=None, private=False, live=False, isoption=False, association=None, selectalias=None, **kwargs):
         self.uselist = uselist
         self.argument = argument
         self.secondary = secondary
@@ -711,6 +711,7 @@ class PropertyLoader(MapperProperty):
         self.live = live
         self.isoption = isoption
         self.association = association
+        self.selectalias = selectalias
         self._hash_key = "%s(%s, %s, %s, %s, %s, %s, %s)" % (self.__class__.__name__, hash_key(self.argument), hash_key(secondary), hash_key(primaryjoin), hash_key(secondaryjoin), hash_key(foreignkey), repr(uselist), repr(private))
 
     def _copy(self):
@@ -728,7 +729,7 @@ class PropertyLoader(MapperProperty):
         if self.association is not None:
             if isinstance(self.association, type):
                 self.association = class_mapper(self.association)
-                
+        
         self.target = self.mapper.table
         self.key = key
         self.parent = parent
@@ -812,16 +813,15 @@ class PropertyLoader(MapperProperty):
     def _match_primaries(self, primary, secondary):
         crit = []
         for fk in secondary.foreign_keys:
-            if fk.column.table is primary:
-                crit.append(fk.column == fk.parent)
+            if fk.references(primary):
+                crit.append(primary.get_col_by_original(fk.column) == fk.parent)
                 self.foreignkey = fk.parent
         for fk in primary.foreign_keys:
-            if fk.column.table is secondary:
-                crit.append(fk.column == fk.parent)
+            if fk.references(secondary):
+                crit.append(secondary.get_col_by_original(fk.column) == fk.parent)
                 self.foreignkey = fk.parent
-
         if len(crit) == 0:
-            raise "Cant find any foreign key relationships between '%s' (%s) and '%s' (%s)" % (primary.table.name, repr(primary.table), secondary.table.name, repr(secondary.table))
+            raise "Cant find any foreign key relationships between '%s' (%s) and '%s' (%s)" % (primary.name, repr(primary), secondary.name, repr(secondary))
         elif len(crit) == 1:
             return (crit[0])
         else:
@@ -1154,6 +1154,26 @@ class EagerLoader(PropertyLoader):
         if self.secondaryjoin is not None:
             [self.to_alias.append(f) for f in self.secondaryjoin._get_from_objects()]
         del self.to_alias[parent.primarytable]
+    
+        # if this eagermapper is to select using an "alias" to isolate it from other
+        # eager mappers against the same table, we have to redefine our secondary
+        # or primary join condition to reference the aliased table.  else
+        # we set up the target clause objects as what they are defined in the 
+        # superclass.
+        if self.selectalias is not None:
+            self.eagertarget = self.target.alias(self.selectalias)
+            aliasizer = Aliasizer(self.target, aliases={self.target:self.eagertarget})
+            if self.secondaryjoin is not None:
+                self.eagersecondary = self.secondaryjoin.copy_container()
+                self.eagersecondary.accept_visitor(aliasizer)
+                self.eagerpriamry = self.primaryjoin
+            else:
+                self.eagerprimary = self.primaryjoin.copy_container()
+                self.eagerprimary.accept_visitor(aliasizer)
+        else:
+            self.eagertarget = self.target
+            self.eagerprimary = self.primaryjoin
+            self.eagersecondary = self.secondaryjoin
 
     def setup(self, key, statement, **options):
         """add a left outer join to the statement thats being constructed"""
@@ -1173,14 +1193,14 @@ class EagerLoader(PropertyLoader):
             towrap = self.parent.table
 
         if self.secondaryjoin is not None:
-            statement._outerjoin = sql.outerjoin(towrap, self.secondary, self.primaryjoin).outerjoin(self.target, self.secondaryjoin)
+            statement._outerjoin = sql.outerjoin(towrap, self.secondary, self.primaryjoin).outerjoin(self.eagertarget, self.eagersecondary)
             statement.order_by(self.secondary.rowid_column)
         else:
-            statement._outerjoin = towrap.outerjoin(self.target, self.primaryjoin)
+            statement._outerjoin = towrap.outerjoin(self.eagertarget, self.eagerprimary)
             statement.order_by(self.target.rowid_column)
 
         statement.append_from(statement._outerjoin)
-        statement.append_column(self.target)
+        statement.append_column(self.eagertarget)
         for key, value in self.mapper.props.iteritems():
             if value is self:
                 raise "Cant use eager loading on a self-referential mapper relationship " + str(self.mapper) + " " + key + repr(self.mapper.props)
@@ -1197,14 +1217,27 @@ class EagerLoader(PropertyLoader):
             
         if not self.uselist:
             if isnew:
-                h.setattr(self.mapper._instance(row, imap))
+                h.setattr(self._instance(row, imap))
             return
         elif isnew:
             result_list = h
         else:
             result_list = getattr(instance, self.key)
-            
-        self.mapper._instance(row, imap, result_list)
+    
+        self._instance(row, imap, result_list)
+
+    def _instance(self, row, imap, result_list=None):
+        """gets an instance from a row, via this EagerLoader's mapper."""
+        # if we have an alias for our mapper's table via the selectalias
+        # parameter, we need to translate the 
+        # aliased columns from the incoming row into a new row that maps
+        # the values against the columns of the mapper's original non-aliased table.
+        if self.selectalias is not None:
+            fakerow = {}
+            for c in self.eagertarget.c:
+                fakerow[c.original] = row[c]
+            row = fakerow
+        return self.mapper._instance(row, imap, result_list)
 
 class MapperOption(object):
     """describes a modification to a Mapper in the context of making a copy
@@ -1253,13 +1286,13 @@ class EagerLazyOption(MapperOption):
 
 class Aliasizer(sql.ClauseVisitor):
     """converts a table instance within an expression to be an alias of that table."""
-    def __init__(self, *tables):
+    def __init__(self, *tables, **kwargs):
         self.tables = {}
         for t in tables:
             self.tables[t] = t
         self.binary = None
         self.match = False
-        self.aliases = {}
+        self.aliases = kwargs.get('aliases', {})
         
     def get_alias(self, table):
         try:
index 3eb68d8b9184b044efbd78578cc695e9589958bd..527d57023c6ed85ce60d3a05a196795592b5a4ef 100644 (file)
@@ -196,7 +196,11 @@ class Column(SchemaItem):
         
     def _make_proxy(self, selectable, name = None):
         """creates a copy of this Column, initialized the way this Column is"""
-        c = Column(name or self.name, self.type, self.foreign_key, self.sequence, key = name or self.key, primary_key = self.primary_key, hidden=self.hidden)
+        if self.foreign_key is None:
+            fk = None
+        else:
+            fk = self.foreign_key.copy()
+        c = Column(name or self.name, self.type, fk, self.sequence, key = name or self.key, primary_key = self.primary_key, hidden=self.hidden)
         c.table = selectable
         c._orig = self.original
         if not c.hidden:
@@ -229,6 +233,16 @@ class ForeignKey(SchemaItem):
             return ForeignKey(self._colspec)
         else:
             return ForeignKey("%s.%s" % (self._colspec.table.name, self._colspec.column.key))
+    
+    def references(self, table):
+        """returns True if the given table is referenced by this ForeignKey."""
+        return (
+            # simple test
+            self.column.table is table      
+            or
+            # test for an indirect relation via a Selectable
+            table.get_col_by_original(self.column) is not None
+        )
         
     def _init_column(self):
         # ForeignKey inits its remote column as late as possible, so tables can
index fc4d4276269ea07af4ba73f334364e16f2f41abd..937bc90478553b1011599fe30e44ab720ff83f5f 100644 (file)
@@ -585,6 +585,13 @@ class Selectable(FromClause):
     def select(self, whereclauses = None, **params):
         return select([self], whereclauses, **params)
 
+    def get_col_by_original(self, column):
+        """given a column which is a schema.Column object attached to a schema.Table object
+        (i.e. an "original" column), return the Column object from this 
+        Selectable which corresponds to that original Column, or None if this Selectable
+        does not contain the column."""
+        raise NotImplementedError()
+
     def join(self, right, *args, **kwargs):
         return Join(self, right, *args, **kwargs)
 
@@ -626,6 +633,13 @@ class Join(Selectable):
         statement"""
         return True
 
+    def get_col_by_original(self, column):
+        for c in self.columns:
+            if c.original is column:
+                return c
+        else:
+            return None
+
     def hash_key(self):
         return "Join(%s, %s, %s, %s)" % (repr(self.left.hash_key()), repr(self.right.hash_key()), repr(self.onclause.hash_key()), repr(self.isouter))
 
@@ -661,6 +675,7 @@ class Alias(Selectable):
     def __init__(self, selectable, alias = None):
         self.selectable = selectable
         self.columns = util.OrderedProperties()
+        self.foreign_keys = []
         if alias is None:
             alias = id(self)
         self.name = alias
@@ -671,10 +686,13 @@ class Alias(Selectable):
             co._make_proxy(self)
 
     primary_keys = property (lambda self: [c for c in self.columns if c.primary_key])
-
+    
     def hash_key(self):
         return "Alias(%s, %s)" % (repr(self.selectable.hash_key()), repr(self.name))
 
+    def get_col_by_original(self, column):
+        return self.columns.get(column.key, None)
+
     def accept_visitor(self, visitor):
         self.selectable.accept_visitor(visitor)
         visitor.visit_alias(self)
@@ -708,6 +726,12 @@ class ColumnImpl(Selectable, CompareMixin):
     def copy_container(self):
         return self.column
 
+    def get_col_by_original(self, column):
+        if self.column.original is column:
+            return self.column
+        else:
+            return None
+            
     def group_parenthesized(self):
         return False
         
@@ -746,7 +770,10 @@ class TableImpl(Selectable):
         return self.table.name
     
     engine = property(lambda s: s.table.engine)
-    
+
+    def get_col_by_original(self, column):
+        return self.columns.get(column.key, None)
+
     def group_parenthesized(self):
         return False
 
@@ -865,6 +892,13 @@ class Select(Selectable):
             else:
                 co._make_proxy(self)
 
+
+    def get_col_by_original(self, column):
+        if self.use_labels:
+            return self.columns.get(column.label,None)
+        else:
+            return self.columns.get(column.key,None)
+
     def append_whereclause(self, whereclause):
         self._append_condition('whereclause', whereclause)
     def append_having(self, having):
@@ -904,7 +938,7 @@ class Select(Selectable):
     def _get_froms(self):
         return [f for f in self._froms.values() if self._correlated is None or not self._correlated.has_key(f.id)]
     froms = property(lambda s: s._get_froms())
-    
+
     def accept_visitor(self, visitor):
         for f in self.froms:
             f.accept_visitor(visitor)
index c474e3d10161f724f6793767e774a21883b30bbb..32a8efa41ada62ac3a9f0f061c98681813f6f59b 100644 (file)
@@ -27,6 +27,10 @@ class OrderedProperties(object):
         self.__dict__['_list'] = []
     def keys(self):
         return self._list
+    def get(self, key, default):
+        return getattr(self, key, default)
+    def has_key(self, key):
+        return hasattr(self, key)
     def __iter__(self):
         return iter([self[x] for x in self._list])
     def __setitem__(self, key, object):