From c0f47eefee5cc54342b714087bc0a74b80671489 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Sat, 26 May 2007 23:21:56 +0000 Subject: [PATCH] - merged "find the equivalent columns" logic together (although both methodologies are needed....) - uniqueappender has to use a set to handle staggered joins --- lib/sqlalchemy/engine/base.py | 3 + lib/sqlalchemy/orm/mapper.py | 107 +++++++++++++++++-------------- lib/sqlalchemy/orm/properties.py | 7 +- lib/sqlalchemy/sql_util.py | 6 +- lib/sqlalchemy/util.py | 11 ++-- test/orm/eagertest2.py | 3 +- 6 files changed, 76 insertions(+), 61 deletions(-) diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index 24812bbedb..f4cb6bb36b 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -888,6 +888,9 @@ class ResultProxy(object): self.__keys.append(colname) self.__props[i] = rec + if self.__echo: + self.context.engine.logger.debug("Cls " + repr(tuple([x[0] for x in metadata]))) + def close(self): """Close this ResultProxy, and the underlying DBAPI cursor corresponding to the execution. diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 5567a7824a..7453d1caa5 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -523,20 +523,14 @@ class Mapper(object): # into one column, where "equivalent" means that one column references the other via foreign key, or # multiple columns that all reference a common parent column. it will also resolve the column # against the "mapped_table" of this mapper. + equivalent_columns = self._get_equivalent_columns() + primary_key = sql.ColumnCollection() - # TODO: wrong ! this is a duplicate / slightly different approach to - # _get_inherited_column_equivalents(). pick one approach and stick with it ! - equivs = {} + for col in (self.primary_key_argument or self.pks_by_table[self.mapped_table]): - if not len(col.foreign_keys): - equivs.setdefault(col, util.Set()).add(col) - else: - for fk in col.foreign_keys: - equivs.setdefault(fk.column, util.Set()).add(col) - for col in equivs: c = self.mapped_table.corresponding_column(col, raiseerr=False) if c is None: - for cc in equivs[col]: + for cc in equivalent_columns[col]: c = self.mapped_table.corresponding_column(cc, raiseerr=False) if c is not None: break @@ -548,12 +542,66 @@ class Mapper(object): raise exceptions.ArgumentError("Could not assemble any primary key columns for mapped table '%s'" % (self.mapped_table.name)) self.primary_key = primary_key - + self.__log("Identified primary key columns: " + str(primary_key)) + _get_clause = sql.and_() for primary_key in self.primary_key: _get_clause.clauses.append(primary_key == sql.bindparam(primary_key._label, type=primary_key.type, unique=True)) self._get_clause = _get_clause + def _get_equivalent_columns(self): + """Create a map of all *equivalent* columns, based on + the determination of column pairs that are equated to + one another either by an established foreign key relationship + or by a joined-table inheritance join. + + This is used to determine the minimal set of primary key + columns for the mapper, as well as when relating + columns to those of a polymorphic selectable (i.e. a UNION of + several mapped tables), as that selectable usually only contains + one column in its columns clause out of a group of several which + are equated to each other. + + The resulting structure is a dictionary of columns mapped + to lists of equivalent columns, i.e. + + { + tablea.col1: + set([tableb.col1, tablec.col1]), + tablea.col2: + set([tabled.col2]) + } + + this method is called repeatedly during the compilation process as + the resulting dictionary contains more equivalents as more inheriting + mappers are compiled. the repetition of this process may be open to some optimization. + """ + + result = {} + def visit_binary(binary): + if binary.operator == '=': + if binary.left in result: + result[binary.left].add(binary.right) + else: + result[binary.left] = util.Set([binary.right]) + if binary.right in result: + result[binary.right].add(binary.left) + else: + result[binary.right] = util.Set([binary.left]) + vis = mapperutil.BinaryVisitor(visit_binary) + + for mapper in self.base_mapper().polymorphic_iterator(): + if mapper.inherit_condition is not None: + vis.traverse(mapper.inherit_condition) + + for col in (self.primary_key_argument or self.pks_by_table[self.mapped_table]): + if not len(col.foreign_keys): + result.setdefault(col, util.Set()).add(col) + else: + for fk in col.foreign_keys: + result.setdefault(fk.column, util.Set()).add(col) + + return result def _compile_properties(self): """Inspect the properties dictionary sent to the Mapper's @@ -769,43 +817,6 @@ class Mapper(object): for m in mapper.polymorphic_iterator(): yield m - def _get_inherited_column_equivalents(self): - """Return a map of all *equivalent* columns, based on - traversing the full set of inherit_conditions across all - inheriting mappers and determining column pairs that are - equated to one another. - - This is used when relating columns to those of a polymorphic - selectable, as the selectable usually only contains one of two (or more) - columns that are equated to one another. - - The resulting structure is a dictionary of columns mapped - to lists of equivalent columns, i.e. - - { - tablea.col1: - [tableb.col1, tablec.col1], - tablea.col2: - [tabled.col2] - } - """ - - result = {} - def visit_binary(binary): - if binary.operator == '=': - if binary.left in result: - result[binary.left].append(binary.right) - else: - result[binary.left] = [binary.right] - if binary.right in result: - result[binary.right].append(binary.left) - else: - result[binary.right] = [binary.left] - vis = mapperutil.BinaryVisitor(visit_binary) - for mapper in self.base_mapper().polymorphic_iterator(): - if mapper.inherit_condition is not None: - vis.traverse(mapper.inherit_condition) - return result def add_properties(self, dict_of_properties): """Add the given dictionary of properties to this mapper, diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index 6677e4ab94..5e65fffb38 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -383,8 +383,9 @@ class PropertyLoader(StrategizedProperty): # as we will be using the polymorphic selectables (i.e. select_table argument to Mapper) to figure this out, # first create maps of all the "equivalent" columns, since polymorphic selectables will often munge # several "equivalent" columns (such as parent/child fk cols) into just one column. - target_equivalents = self.mapper._get_inherited_column_equivalents() + target_equivalents = self.mapper._get_equivalent_columns() + # if the target mapper loads polymorphically, adapt the clauses to the target's selectable if self.loads_polymorphic: if self.secondaryjoin: @@ -403,7 +404,7 @@ class PropertyLoader(StrategizedProperty): for c in list(self.remote_side): if self.secondary and c in self.secondary.columns: continue - for equiv in [c] + (c in target_equivalents and target_equivalents[c] or []): + for equiv in [c] + (c in target_equivalents and list(target_equivalents[c]) or []): corr = self.mapper.select_table.corresponding_column(equiv, raiseerr=False) if corr: self.remote_side.add(corr) @@ -454,7 +455,7 @@ class PropertyLoader(StrategizedProperty): try: return self._parent_join_cache[(parent, primary, secondary)] except KeyError: - parent_equivalents = parent._get_inherited_column_equivalents() + parent_equivalents = parent._get_equivalent_columns() primaryjoin = self.polymorphic_primaryjoin.copy_container() if self.secondaryjoin is not None: secondaryjoin = self.polymorphic_secondaryjoin.copy_container() diff --git a/lib/sqlalchemy/sql_util.py b/lib/sqlalchemy/sql_util.py index debf1da4f9..9f8bf276ec 100644 --- a/lib/sqlalchemy/sql_util.py +++ b/lib/sqlalchemy/sql_util.py @@ -67,12 +67,12 @@ class TableCollection(object): class TableFinder(TableCollection, sql.NoColumnVisitor): """locate all Tables within a clause.""" - def __init__(self, table, check_columns=False, include_aliases=False): + def __init__(self, clause, check_columns=False, include_aliases=False): TableCollection.__init__(self) self.check_columns = check_columns self.include_aliases = include_aliases - if table is not None: - self.traverse(table) + if clause is not None: + self.traverse(clause) def visit_alias(self, alias): if self.include_aliases: diff --git a/lib/sqlalchemy/util.py b/lib/sqlalchemy/util.py index 02dff674be..cb32c07ac2 100644 --- a/lib/sqlalchemy/util.py +++ b/lib/sqlalchemy/util.py @@ -417,21 +417,22 @@ class OrderedSet(Set): __isub__ = difference_update class UniqueAppender(object): - """appends items to a list such that consecutive repeats of - a particular item are skipped.""" + """appends items to a collection such that only unique items + are added.""" def __init__(self, data): self.data = data + self._unique = Set() if hasattr(data, 'append'): self._data_appender = data.append elif hasattr(data, 'add'): + # TODO: we think its a set here. bypass unneeded uniquing logic ? self._data_appender = data.add - self.__last = None def append(self, item): - if item is not self.__last: + if item not in self._unique: self._data_appender(item) - self.__last = item + self._unique.add(item) def __iter__(self): return iter(self.data) diff --git a/test/orm/eagertest2.py b/test/orm/eagertest2.py index ef385df163..78e1d98707 100644 --- a/test/orm/eagertest2.py +++ b/test/orm/eagertest2.py @@ -231,9 +231,8 @@ class EagerTest(AssertMixin): ctx.current.clear() i = ctx.current.query(Invoice).get(invoice_id) - self.echo(repr(i)) - self.assert_(repr(i.company) == repr(c)) + assert repr(i.company) == repr(c), repr(i.company) + " does not match " + repr(c) if __name__ == "__main__": testbase.main() -- 2.47.3