]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- merged "find the equivalent columns" logic together (although both methodologies...
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 26 May 2007 23:21:56 +0000 (23:21 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 26 May 2007 23:21:56 +0000 (23:21 +0000)
- uniqueappender has to use a set to handle staggered joins

lib/sqlalchemy/engine/base.py
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/properties.py
lib/sqlalchemy/sql_util.py
lib/sqlalchemy/util.py
test/orm/eagertest2.py

index 24812bbedbf4a139ec47cacb761fd72b5ebb9f04..f4cb6bb36b6ce7117f5b54494528a46956b54278 100644 (file)
@@ -888,6 +888,9 @@ class ResultProxy(object):
                 self.__keys.append(colname)
                 self.__props[i] = rec
 
+            if self.__echo:
+                self.context.engine.logger.debug("Cls " + repr(tuple([x[0] for x in metadata])))
+
     def close(self):
         """Close this ResultProxy, and the underlying DBAPI cursor corresponding to the execution.
 
index 5567a7824a5eb8071dc7f00a58b1a6606f040216..7453d1caa5197a30369e7d2056f5ff6130812727 100644 (file)
@@ -523,20 +523,14 @@ class Mapper(object):
         # into one column, where "equivalent" means that one column references the other via foreign key, or
         # multiple columns that all reference a common parent column.  it will also resolve the column
         # against the "mapped_table" of this mapper.
+        equivalent_columns = self._get_equivalent_columns()
+        
         primary_key = sql.ColumnCollection()
-        # TODO: wrong !  this is a duplicate / slightly different approach to 
-        # _get_inherited_column_equivalents().  pick one approach and stick with it !
-        equivs = {}
+
         for col in (self.primary_key_argument or self.pks_by_table[self.mapped_table]):
-            if not len(col.foreign_keys):
-                equivs.setdefault(col, util.Set()).add(col)
-            else:
-                for fk in col.foreign_keys:
-                    equivs.setdefault(fk.column, util.Set()).add(col)
-        for col in equivs:
             c = self.mapped_table.corresponding_column(col, raiseerr=False)
             if c is None:
-                for cc in equivs[col]:
+                for cc in equivalent_columns[col]:
                     c = self.mapped_table.corresponding_column(cc, raiseerr=False)
                     if c is not None:
                         break
@@ -548,12 +542,66 @@ class Mapper(object):
             raise exceptions.ArgumentError("Could not assemble any primary key columns for mapped table '%s'" % (self.mapped_table.name))
 
         self.primary_key = primary_key
-
+        self.__log("Identified primary key columns: " + str(primary_key))
+        
         _get_clause = sql.and_()
         for primary_key in self.primary_key:
             _get_clause.clauses.append(primary_key == sql.bindparam(primary_key._label, type=primary_key.type, unique=True))
         self._get_clause = _get_clause
 
+    def _get_equivalent_columns(self):
+        """Create a map of all *equivalent* columns, based on
+        the determination of column pairs that are equated to
+        one another either by an established foreign key relationship
+        or by a joined-table inheritance join.
+
+        This is used to determine the minimal set of primary key 
+        columns for the mapper, as well as when relating 
+        columns to those of a polymorphic selectable (i.e. a UNION of
+        several mapped tables), as that selectable usually only contains
+        one column in its columns clause out of a group of several which
+        are equated to each other.
+
+        The resulting structure is a dictionary of columns mapped
+        to lists of equivalent columns, i.e.
+
+        {
+            tablea.col1: 
+                set([tableb.col1, tablec.col1]),
+            tablea.col2:
+                set([tabled.col2])
+        }
+        
+        this method is called repeatedly during the compilation process as 
+        the resulting dictionary contains more equivalents as more inheriting 
+        mappers are compiled.  the repetition of this process may be open to some optimization.
+        """
+
+        result = {}
+        def visit_binary(binary):
+            if binary.operator == '=':
+                if binary.left in result:
+                    result[binary.left].add(binary.right)
+                else:
+                    result[binary.left] = util.Set([binary.right])
+                if binary.right in result:
+                    result[binary.right].add(binary.left)
+                else:
+                    result[binary.right] = util.Set([binary.left])
+        vis = mapperutil.BinaryVisitor(visit_binary)
+
+        for mapper in self.base_mapper().polymorphic_iterator():
+            if mapper.inherit_condition is not None:
+                vis.traverse(mapper.inherit_condition)
+
+        for col in (self.primary_key_argument or self.pks_by_table[self.mapped_table]):
+            if not len(col.foreign_keys):
+                result.setdefault(col, util.Set()).add(col)
+            else:
+                for fk in col.foreign_keys:
+                    result.setdefault(fk.column, util.Set()).add(col)
+
+        return result
         
     def _compile_properties(self):
         """Inspect the properties dictionary sent to the Mapper's
@@ -769,43 +817,6 @@ class Mapper(object):
             for m in mapper.polymorphic_iterator():
                 yield m
 
-    def _get_inherited_column_equivalents(self):
-        """Return a map of all *equivalent* columns, based on
-        traversing the full set of inherit_conditions across all
-        inheriting mappers and determining column pairs that are
-        equated to one another.
-
-        This is used when relating columns to those of a polymorphic
-        selectable, as the selectable usually only contains one of two (or more)
-        columns that are equated to one another.
-        
-        The resulting structure is a dictionary of columns mapped
-        to lists of equivalent columns, i.e.
-        
-        {
-            tablea.col1: 
-                [tableb.col1, tablec.col1],
-            tablea.col2:
-                [tabled.col2]
-        }
-        """
-
-        result = {}
-        def visit_binary(binary):
-            if binary.operator == '=':
-                if binary.left in result:
-                    result[binary.left].append(binary.right)
-                else:
-                    result[binary.left] = [binary.right]
-                if binary.right in result:
-                    result[binary.right].append(binary.left)
-                else:
-                    result[binary.right] = [binary.left]
-        vis = mapperutil.BinaryVisitor(visit_binary)
-        for mapper in self.base_mapper().polymorphic_iterator():
-            if mapper.inherit_condition is not None:
-                vis.traverse(mapper.inherit_condition)
-        return result
 
     def add_properties(self, dict_of_properties):
         """Add the given dictionary of properties to this mapper,
index 6677e4ab94628afb92a7351e408a120732cee0ec..5e65fffb389ec5cd00cb4751c8ff359a20b7ae78 100644 (file)
@@ -383,8 +383,9 @@ class PropertyLoader(StrategizedProperty):
         # as we will be using the polymorphic selectables (i.e. select_table argument to Mapper) to figure this out,
         # first create maps of all the "equivalent" columns, since polymorphic selectables will often munge
         # several "equivalent" columns (such as parent/child fk cols) into just one column.
-        target_equivalents = self.mapper._get_inherited_column_equivalents()
 
+        target_equivalents = self.mapper._get_equivalent_columns()
+            
         # if the target mapper loads polymorphically, adapt the clauses to the target's selectable
         if self.loads_polymorphic:
             if self.secondaryjoin:
@@ -403,7 +404,7 @@ class PropertyLoader(StrategizedProperty):
             for c in list(self.remote_side):
                 if self.secondary and c in self.secondary.columns:
                     continue
-                for equiv in [c] + (c in target_equivalents and target_equivalents[c] or []): 
+                for equiv in [c] + (c in target_equivalents and list(target_equivalents[c]) or []): 
                     corr = self.mapper.select_table.corresponding_column(equiv, raiseerr=False)
                     if corr:
                         self.remote_side.add(corr)
@@ -454,7 +455,7 @@ class PropertyLoader(StrategizedProperty):
         try:
             return self._parent_join_cache[(parent, primary, secondary)]
         except KeyError:
-            parent_equivalents = parent._get_inherited_column_equivalents()
+            parent_equivalents = parent._get_equivalent_columns()
             primaryjoin = self.polymorphic_primaryjoin.copy_container()
             if self.secondaryjoin is not None:
                 secondaryjoin = self.polymorphic_secondaryjoin.copy_container()
index debf1da4f9c5b5ca37216b2d9e360e3be74ec39b..9f8bf276ecef839b13db2093d93389086af93496 100644 (file)
@@ -67,12 +67,12 @@ class TableCollection(object):
 class TableFinder(TableCollection, sql.NoColumnVisitor):
     """locate all Tables within a clause."""
 
-    def __init__(self, table, check_columns=False, include_aliases=False):
+    def __init__(self, clause, check_columns=False, include_aliases=False):
         TableCollection.__init__(self)
         self.check_columns = check_columns
         self.include_aliases = include_aliases
-        if table is not None:
-            self.traverse(table)
+        if clause is not None:
+            self.traverse(clause)
 
     def visit_alias(self, alias):
         if self.include_aliases:
index 02dff674be51bae2d1dd988791661d2fd0f01cb1..cb32c07ac20f9a267b9c45a9189d936b6fb79b1c 100644 (file)
@@ -417,21 +417,22 @@ class OrderedSet(Set):
     __isub__ = difference_update
 
 class UniqueAppender(object):
-    """appends items to a list such that consecutive repeats of
-    a particular item are skipped."""
+    """appends items to a collection such that only unique items
+    are added."""
     
     def __init__(self, data):
         self.data = data
+        self._unique = Set()
         if hasattr(data, 'append'):
             self._data_appender = data.append
         elif hasattr(data, 'add'):
+            # TODO: we think its a set here.  bypass unneeded uniquing logic ?
             self._data_appender = data.add
-        self.__last = None
         
     def append(self, item):
-        if item is not self.__last:
+        if item not in self._unique:
             self._data_appender(item)
-            self.__last = item
+            self._unique.add(item)
     
     def __iter__(self):
         return iter(self.data)
index ef385df16314b9ed2d754bde13d47ab34b616d15..78e1d987072751cece645ab8baa5889d51c40e08 100644 (file)
@@ -231,9 +231,8 @@ class EagerTest(AssertMixin):
         ctx.current.clear()
 
         i = ctx.current.query(Invoice).get(invoice_id)
-        self.echo(repr(i))
 
-        self.assert_(repr(i.company) == repr(c))
+        assert repr(i.company) == repr(c), repr(i.company) +  " does not match " + repr(c)
         
 if __name__ == "__main__":    
     testbase.main()