From: Mike Bayer Date: Sun, 28 Jan 2007 23:33:53 +0000 (+0000) Subject: merged the polymorphic relationship refactoring branch in. i want to go further... X-Git-Tag: rel_0_3_5~83 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=bbc5e7c285a160f148eafa0ab442675fe88551ce;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git merged the polymorphic relationship refactoring branch in. i want to go further on that branch and introduce the foreign_keys argument, and further centralize the "intelligence" about the joins and selectables into PropertyLoader so that lazyloader/sync can be simplified, but the current branch goes pretty far. - relations keep track of "polymorphic_primaryjoin", "polymorphic_secondaryjoin" which it derives from the plain primaryjoin/secondaryjoin. - lazy/eagerloaders work from those polymorphic join objects. - the join exported by PropertyLoader to Query/SelectResults is the polymorphic join, so that join_to/etc work properly. - Query builds itself against the base Mapper again, not the "polymorphic" mapper. uses the "polymorphic" version only as appropriate. this helps join_by/join_to/etc to work with polymorphic mappers. - Query will also adapt incoming WHERE criterion to the polymorphic mapper, i.e. the "people" table becomes the "person_join" automatically. - quoting has been modified since labels made out of non-case-sensitive columns could themselves require quoting..so case_sensitive defaults to True if not otherwise specified (used to be based on the identifier itself). - the test harness gets an ORMTest base class and a bunch of the ORM unit tests are using it now, decreases a lot of redundancy. --- diff --git a/CHANGES b/CHANGES index 140139987a..827cc05898 100644 --- a/CHANGES +++ b/CHANGES @@ -1,12 +1,18 @@ 0.3.5 - orm: - - adjustments to the recent polymorphic relationship refactorings, specifically - for many-to-one relationships to polymorphic unions that did not contain the - base table [ticket:439]. the lazy/eager clause adaption to the selectable - will match up on straight column names (i.e. its a more liberal policy) + - further rework of the recent polymorphic relationship refactorings, as well + as the mechanics of relationships overall. Allows more accurate ORM behavior + with relationships from/to/between polymorphic mappers, as well as their usage + with Query, SelectResults. tickets include [ticket:439], [ticket:441]. + relationship mechanics are still a work in progress, more to come ! - eager relation to an inheriting mapper wont fail if no rows returned for the relationship. - - fix for multi-level polymorphic mappers + - eager loading is slightly more strict about detecting "self-referential" + relationships, specifically between polymorphic mappers. + - the value of "case_sensitive" defaults to True now, regardless of the casing + of the identifier, unless specifically set to False. this is because the + object might be label'ed as something else which does contain mixed case, and + propigating "case_sensitive=False" breaks that. - oracle: - when returning "rowid" as the ORDER BY column or in use with ROW_NUMBER OVER, oracle dialect checks the selectable its being applied to and will switch to diff --git a/lib/sqlalchemy/ext/selectresults.py b/lib/sqlalchemy/ext/selectresults.py index c2ad409179..153b0c2b9a 100644 --- a/lib/sqlalchemy/ext/selectresults.py +++ b/lib/sqlalchemy/ext/selectresults.py @@ -71,7 +71,10 @@ class SelectResults(object): def select(self, clause): return self.filter(clause) - + + def select_by(self, *args, **kwargs): + return self.filter(self._query._join_by(args, kwargs, start=self._joinpoint[1])) + def order_by(self, order_by): """apply an ORDER BY to the query.""" new = self.clone() @@ -131,9 +134,12 @@ class SelectResults(object): for key in keys: prop = mapper.props[key] if outerjoin: - clause = clause.outerjoin(prop.mapper.mapped_table, prop.get_join()) + clause = clause.outerjoin(prop.select_table, prop.get_join()) else: - clause = clause.join(prop.mapper.mapped_table, prop.get_join()) + clause = clause.join(prop.select_table, prop.get_join()) + print "SELECT_TABLE", prop.select_table + print "JOIN", prop.get_join() + print "CLAUSE", str(clause), "DONE CLAUSE" mapper = prop.mapper return (clause, mapper) diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index e0ee36fac6..d6967934cd 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -500,8 +500,9 @@ class Mapper(object): def _initialize_properties(self): - """calls the init() method on all MapperProperties attached to this mapper. this will incur the - compilation of related mappers.""" + """calls the init() method on all MapperProperties attached to this mapper. this happens + after all mappers have completed compiling everything else up until this point, so that all + dependencies are fully available.""" self.__log("_initialize_properties() started") l = [(key, prop) for key, prop in self.__props.iteritems()] for key, prop in l: @@ -514,7 +515,8 @@ class Mapper(object): """if the 'select_table' keyword argument was specified, set up a second "surrogate mapper" that will be used for select operations. the columns of select_table should encompass all the columns of the mapped_table either directly - or through proxying relationships.""" + or through proxying relationships. Currently, non-column properties are *not* copied. this implies + that a polymorphic mapper cant do any eager loading right now.""" if self.select_table is not self.mapped_table: if self.polymorphic_identity is None: raise exceptions.ArgumentError("Could not locate a polymorphic_identity field for mapper '%s'. This field is required for polymorphic mappers" % str(self)) @@ -531,7 +533,7 @@ class Mapper(object): """if this mapper is to be a primary mapper (i.e. the non_primary flag is not set), associate this Mapper with the given class_ and entity name. subsequent calls to class_mapper() for the class_/entity name combination will return this - mapper. also decorates the __init__ method on the mapped class to include auto-session attachment logic.""" + mapper. also decorates the __init__ method on the mapped class to include optional auto-session attachment logic.""" if self.non_primary: return @@ -626,7 +628,24 @@ class Mapper(object): for x in iterate(mapper): yield x return iterate(self) - + + def _get_inherited_column_equivalents(self): + """return a map of all 'equivalent' columns, based on traversing the full set of inherit_conditions across + all inheriting mappers and determining column pairs that are equated to one another. + + this is used when relating columns to those of a polymorphic selectable, as the selectable usually only contains + one of two columns that are equated to one another.""" + result = {} + def visit_binary(binary): + if binary.operator == '=': + result[binary.left] = binary.right + result[binary.right] = binary.left + vis = mapperutil.BinaryVisitor(visit_binary) + for mapper in self.polymorphic_iterator(): + if mapper.inherit_condition is not None: + mapper.inherit_condition.accept_visitor(vis) + return result + def add_properties(self, dict_of_properties): """adds the given dictionary of properties to this mapper, using add_property.""" for key, value in dict_of_properties.iteritems(): @@ -755,8 +774,8 @@ class Mapper(object): def identity_key_from_row(self, row): """return an identity-map key for use in storing/retrieving an item from the identity map. - row - a sqlalchemy.dbengine.RowProxy instance or other map corresponding result-set - column names to their values within a row. + row - a sqlalchemy.engine.base.RowProxy instance or a dictionary corresponding result-set + ColumnElement instances to their values within a row. """ return (self.class_, tuple([row[column] for column in self.pks_by_table[self.mapped_table]]), self.entity_name) diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index bc74fcafd3..42c017bc5e 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -201,6 +201,9 @@ class PropertyLoader(StrategizedProperty): self.association = mapper.class_mapper(self.association, compile=False)._check_compile() self.target = self.mapper.mapped_table + self.select_mapper = self.mapper.get_select_mapper() + self.select_table = self.mapper.select_table + self.loads_polymorphic = self.target is not self.select_table if self.cascade.delete_orphan: if self.parent.class_ is self.mapper.class_: @@ -226,7 +229,7 @@ class PropertyLoader(StrategizedProperty): # as the loader strategies expect to be working with those now (they will adapt the join conditions # to the "polymorphic" selectable as needed). since this is an API change, put an explicit check/ # error message in case its the "old" way. - if self.mapper.select_table is not self.mapper.mapped_table: + if self.loads_polymorphic: vis = sql_util.ColumnsInClause(self.mapper.select_table) self.primaryjoin.accept_visitor(vis) if self.secondaryjoin: @@ -234,6 +237,7 @@ class PropertyLoader(StrategizedProperty): if vis.result: raise exceptions.ArgumentError("In relationship '%s' between mappers '%s' and '%s', primary and secondary join conditions must not include columns from the polymorphic 'select_table' argument as of SA release 0.3.4. Construct join conditions using the base tables of the related mappers." % (self.key, self.parent, self.mapper)) + # if the foreign key wasnt specified and theres no assocaition table, try to figure # out who is dependent on who. we dont need all the foreign keys represented in the join, # just one of them. @@ -247,6 +251,52 @@ class PropertyLoader(StrategizedProperty): if self.direction is None: self.direction = self._get_direction() + #print "DIRECTION IS ", self.direction, sync.ONETOMANY, sync.MANYTOONE + #print "FKEY IS", self.foreignkey + + # get ready to create "polymorphic" primary/secondary join clauses. + # these clauses represent the same join between parent/child tables that the primary + # and secondary join clauses represent, except they reference ColumnElements that are specifically + # in the "polymorphic" selectables. these are used to construct joins for both Query as well as + # eager loading, and also are used to calculate "lazy loading" clauses. + + # as we will be using the polymorphic selectables (i.e. select_table argument to Mapper) to figure this out, + # first create maps of all the "equivalent" columns, since polymorphic selectables will often munge + # several "equivalent" columns (such as parent/child fk cols) into just one column. + parent_equivalents = self.parent._get_inherited_column_equivalents() + target_equivalents = self.mapper._get_inherited_column_equivalents() + + # if the target mapper loads polymorphically, adapt the clauses to the target's selectable + if self.loads_polymorphic: + if self.secondaryjoin: + self.polymorphic_secondaryjoin = self.secondaryjoin.copy_container() + self.polymorphic_secondaryjoin.accept_visitor(sql_util.ClauseAdapter(self.mapper.select_table)) + self.polymorphic_primaryjoin = self.primaryjoin.copy_container() + else: + self.polymorphic_primaryjoin = self.primaryjoin.copy_container() + if self.direction is sync.ONETOMANY: + self.polymorphic_primaryjoin.accept_visitor(sql_util.ClauseAdapter(self.mapper.select_table, include=self.foreignkey, equivalents=target_equivalents)) + elif self.direction is sync.MANYTOONE: + self.polymorphic_primaryjoin.accept_visitor(sql_util.ClauseAdapter(self.mapper.select_table, exclude=self.foreignkey, equivalents=target_equivalents)) + + self.polymorphic_secondaryjoin = None + else: + self.polymorphic_primaryjoin = self.primaryjoin.copy_container() + self.polymorphic_secondaryjoin = self.secondaryjoin and self.secondaryjoin.copy_container() or None + + # if the parent mapper loads polymorphically, adapt the clauses to the parent's selectable + if self.parent.select_table is not self.parent.mapped_table: + if self.direction is sync.ONETOMANY: + self.polymorphic_primaryjoin.accept_visitor(sql_util.ClauseAdapter(self.parent.select_table, exclude=self.foreignkey, equivalents=parent_equivalents)) + elif self.direction is sync.MANYTOONE: + self.polymorphic_primaryjoin.accept_visitor(sql_util.ClauseAdapter(self.parent.select_table, include=self.foreignkey, equivalents=parent_equivalents)) + elif self.secondaryjoin: + self.polymorphic_primaryjoin.accept_visitor(sql_util.ClauseAdapter(self.parent.select_table, exclude=self.foreignkey, equivalents=parent_equivalents)) + + #print "KEY", self.key, "PARENT", str(self.parent) + #print "KEY", self.key, "REG PRIMARY JOIN", str(self.primaryjoin) + #print "KEY", self.key, "POLY PRIMARY JOIN", str(self.polymorphic_primaryjoin) + if self.uselist is None and self.direction == sync.MANYTOONE: self.uselist = False @@ -326,10 +376,10 @@ class PropertyLoader(StrategizedProperty): self.foreignkey = foreignkeys def get_join(self): - if self.secondaryjoin is not None: - return self.primaryjoin & self.secondaryjoin + if self.polymorphic_secondaryjoin is not None: + return self.polymorphic_primaryjoin & self.polymorphic_secondaryjoin else: - return self.primaryjoin + return self.polymorphic_primaryjoin def register_dependencies(self, uowcommit): if not self.viewonly: diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index ce2f78bb73..e0294ef72d 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -5,7 +5,7 @@ # the MIT License: http://www.opensource.org/licenses/mit-license.php from sqlalchemy import sql, util, exceptions, sql_util, logging -from sqlalchemy.orm import mapper +from sqlalchemy.orm import mapper, class_mapper from sqlalchemy.orm.interfaces import OperationContext __all__ = ['Query', 'QueryContext', 'SelectionContext'] @@ -18,7 +18,7 @@ class Query(object): else: self.mapper = class_or_mapper.compile() self.with_options = with_options or [] - self.mapper = self.mapper.get_select_mapper().compile() + self.select_mapper = self.mapper.get_select_mapper().compile() self.always_refresh = kwargs.pop('always_refresh', self.mapper.always_refresh) self.order_by = kwargs.pop('order_by', self.mapper.order_by) self.lockmode = lockmode @@ -26,10 +26,11 @@ class Query(object): if extension is not None: self.extension.append(extension) self.extension.append(self.mapper.extension) + self.is_polymorphic = self.mapper is not self.select_mapper self._session = session if not hasattr(self.mapper, '_get_clause'): _get_clause = sql.and_() - for primary_key in self.mapper.pks_by_table[self.table]: + for primary_key in self.primary_key_columns: _get_clause.clauses.append(primary_key == sql.bindparam(primary_key._label, type=primary_key.type)) self.mapper._get_clause = _get_clause self._get_clause = self.mapper._get_clause @@ -44,7 +45,8 @@ class Query(object): return self.mapper.get_session() else: return self._session - table = property(lambda s:s.mapper.select_table) + table = property(lambda s:s.select_mapper.mapped_table) + primary_key_columns = property(lambda s:s.select_mapper.pks_by_table[s.select_mapper.mapped_table]) session = property(_get_session) def get(self, ident, **kwargs): @@ -116,6 +118,10 @@ class Query(object): return self.select_whereclause(self.join_by(*args, **params)) def join_by(self, *args, **params): + """return a ClauseElement representing the WHERE clause that would normally be sent to select_whereclause() by select_by().""" + return self._join_by(args, params) + + def _join_by(self, args, params, start=None): """return a ClauseElement representing the WHERE clause that would normally be sent to select_whereclause() by select_by().""" clause = None for arg in args: @@ -125,7 +131,7 @@ class Query(object): clause &= arg for key, value in params.iteritems(): - (keys, prop) = self._locate_prop(key) + (keys, prop) = self._locate_prop(key, start=start) c = prop.compare(value) & self.join_via(keys) if clause is None: clause = c @@ -265,7 +271,7 @@ class Query(object): if self._nestable(**kwargs): s = sql.select([self.table], whereclause, **kwargs).alias('getcount').count() else: - primary_key = self.mapper.pks_by_table[self.table] + primary_key = self.primary_key_columns s = sql.select([sql.func.count(list(primary_key)[0])], whereclause, from_obj=from_obj, **kwargs) return self.session.scalar(self.mapper, s, params=params) @@ -317,7 +323,7 @@ class Query(object): session = self.session - context = SelectionContext(self.mapper, session, with_options=self.with_options, **kwargs) + context = SelectionContext(self.select_mapper, session, with_options=self.with_options, **kwargs) result = util.UniqueAppender([]) if mappers: @@ -326,7 +332,7 @@ class Query(object): otherresults.append(util.UniqueAppender([])) for row in cursor.fetchall(): - self.mapper._instance(context, row, result) + self.select_mapper._instance(context, row, result) i = 0 for m in mappers: m._instance(context, row, otherresults[i]) @@ -356,7 +362,7 @@ class Query(object): ident = util.to_list(ident) i = 0 params = {} - for primary_key in self.mapper.pks_by_table[self.table]: + for primary_key in self.primary_key_columns: params[primary_key._label] = ident[i] # if there are not enough elements in the given identifier, then # use the previous identifier repeatedly. this is a workaround for the issue @@ -392,6 +398,16 @@ class Query(object): def compile(self, whereclause = None, **kwargs): """given a WHERE criterion, produce a ClauseElement-based statement suitable for usage in the execute() method.""" + + if whereclause is not None and self.is_polymorphic: + # adapt the given WHERECLAUSE to adjust instances of this query's mapped table to be that of our select_table, + # which may be the "polymorphic" selectable used by our mapper. + print "PolYMORPHIC YES" + print "WHERECLAUSE", str(whereclause) + print "OUR TABLE", str(self.table) + whereclause.accept_visitor(sql_util.ClauseAdapter(self.table)) + print "AND NOW ITS", str(whereclause) + context = kwargs.pop('query_context', None) if context is None: context = QueryContext(self, kwargs) @@ -412,8 +428,8 @@ class Query(object): except KeyError: raise exceptions.ArgumentError("Unknown lockmode '%s'" % lockmode) - if self.mapper.single and self.mapper.polymorphic_on is not None and self.mapper.polymorphic_identity is not None: - whereclause = sql.and_(whereclause, self.mapper.polymorphic_on.in_(*[m.polymorphic_identity for m in self.mapper.polymorphic_iterator()])) + if self.select_mapper.single and self.select_mapper.polymorphic_on is not None and self.select_mapper.polymorphic_identity is not None: + whereclause = sql.and_(whereclause, self.select_mapper.polymorphic_on.in_(*[m.polymorphic_identity for m in self.select_mapper.polymorphic_iterator()])) alltables = [] for l in [sql_util.TableFinder(x) for x in from_obj]: @@ -460,7 +476,9 @@ class Query(object): context.statement = statement # give all the attached properties a chance to modify the query - for value in self.mapper.props.values(): + # TODO: doing this off the select_mapper. if its the polymorphic mapper, then + # it has no relations() on it. should we compile those too into the query ? (i.e. eagerloads) + for value in self.select_mapper.props.values(): value.setup(context) return statement diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index 863f2e65d3..d8d9a9c470 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -125,20 +125,8 @@ class DeferredOption(StrategizedOption): class AbstractRelationLoader(LoaderStrategy): def init(self): super(AbstractRelationLoader, self).init() - self.primaryjoin = self.parent_property.primaryjoin - self.secondaryjoin = self.parent_property.secondaryjoin - self.secondary = self.parent_property.secondary - self.foreignkey = self.parent_property.foreignkey - self.mapper = self.parent_property.mapper - self.select_mapper = self.mapper.get_select_mapper() - self.target = self.parent_property.target - self.select_table = self.parent_property.mapper.select_table - self.loads_polymorphic = self.target is not self.select_table - self.uselist = self.parent_property.uselist - self.cascade = self.parent_property.cascade - self.attributeext = self.parent_property.attributeext - self.order_by = self.parent_property.order_by - self.remote_side = self.parent_property.remote_side + for attr in ['primaryjoin', 'secondaryjoin', 'secondary', 'foreignkey', 'mapper', 'select_mapper', 'target', 'select_table', 'loads_polymorphic', 'uselist', 'cascade', 'attributeext', 'order_by', 'remote_side', 'polymorphic_primaryjoin', 'polymorphic_secondaryjoin', 'direction']: + setattr(self, attr, getattr(self.parent_property, attr)) self._should_log_debug = logging.is_debug_enabled(self.logger) def _init_instance_attribute(self, instance, callable_=None): @@ -163,7 +151,14 @@ NoLoader.logger = logging.class_logger(NoLoader) class LazyLoader(AbstractRelationLoader): def init(self): super(LazyLoader, self).init() - (self.lazywhere, self.lazybinds, self.lazyreverse) = self._create_lazy_clause(self.parent.unjoined_table, self.primaryjoin, self.secondaryjoin, self.foreignkey, self.remote_side, self.mapper.select_table) + (self.lazywhere, self.lazybinds, self.lazyreverse) = self._create_lazy_clause( + self.parent.select_table, + self.mapper.select_table, + self.polymorphic_primaryjoin, + self.polymorphic_secondaryjoin, + self.foreignkey, + self.remote_side) + # determine if our "lazywhere" clause is the same as the mapper's # get() clause. then we can just use mapper.get() self.use_get = not self.uselist and query.Query(self.mapper)._get_clause.compare(self.lazywhere) @@ -210,7 +205,7 @@ class LazyLoader(AbstractRelationLoader): # to possibly save a DB round trip if self.use_get: ident = [] - for primary_key in self.mapper.pks_by_table[self.mapper.mapped_table]: + for primary_key in self.select_mapper.pks_by_table[self.select_mapper.mapped_table]: bind = self.lazyreverse[primary_key] ident.append(params[bind.key]) return session.query(self.mapper).get(ident) @@ -247,11 +242,49 @@ class LazyLoader(AbstractRelationLoader): # to load data into it. sessionlib.attribute_manager.reset_instance_attribute(instance, self.key) - def _create_lazy_clause(self, table, primaryjoin, secondaryjoin, foreignkey, remote_side, select_table): + def _create_lazy_clause(self, parenttable, targettable, primaryjoin, secondaryjoin, foreignkey, remote_side): binds = {} reverse = {} - def column_in_table(table, column): - return table.corresponding_column(column, raiseerr=False, keys_ok=False) is not None + + #print "PARENTTABLE", parenttable, "TARGETTABLE", targettable + + def should_bind(targetcol, othercol): + # determine if the given target column is part of the parent table + # portion of the join condition, in which case it gets converted + # to a bind param. + + # contains_column will return if this column is exactly in the table, with no + # proxying relationships. the table can be either the column's actual parent table, + # or a Join object containing the table. for a Select, Alias, or Union, the column + # needs to be the actual ColumnElement exported by that selectable, not the "originating" column. + inparent = parenttable.c.contains_column(targetcol) + + # check if its also in the target table. if this is a many-to-many relationship, + # then we dont care about target table presence + intarget = secondaryjoin is None and targettable.c.contains_column(targetcol) + + if inparent and not intarget: + # its in the parent and not the target, return true. + return True + elif inparent and intarget: + # its in both. hmm. + if parenttable is not targettable: + # the column is in both tables, but the two tables are different. + # this corresponds to a table relating to a Join which also contains that table. + # such as tableA.c.col1 == tableB.c.col2, tables are tableA and tableA.join(tableB) + # in which case we only accept that the parenttable is the "base" table, not the "joined" table + return targetcol.table is parenttable + else: + # parent/target are the same table, i.e. circular reference. + # we have to rely on the "remote_side" argument + # and/or foreignkey collection. + # technically we can use this for the non-circular refs as well except that "remote_side" is usually + # only calculated for self-referential relationships at the moment. + # TODO: have PropertyLoader calculate remote_side completely ? this would involve moving most of the + # "should_bind()" logic to PropertyLoader. remote_side could also then be accurately used by sync.py. + if col_in_collection(othercol, remote_side): + return True + return False if remote_side is None or len(remote_side) == 0: remote_side = foreignkey @@ -280,14 +313,15 @@ class LazyLoader(AbstractRelationLoader): rightcol = find_column_in_expr(binary.right) if leftcol is None or rightcol is None: return - circular = leftcol.table is rightcol.table - if ((not circular and column_in_table(table, leftcol)) or (circular and col_in_collection(rightcol, remote_side))): + if should_bind(leftcol, rightcol): col = leftcol binary.left = binds.setdefault(leftcol, sql.bindparam(bind_label(), None, shortname=leftcol.name, type=binary.right.type)) reverse[rightcol] = binds[col] - if (leftcol is not rightcol) and ((not circular and column_in_table(table, rightcol)) or (circular and col_in_collection(leftcol, remote_side))): + # the "left is not right" compare is to handle part of a join clause that is "table.c.col1==table.c.col1", + # which can happen in rare cases + if leftcol is not rightcol and should_bind(rightcol, leftcol): col = rightcol binary.right = binds.setdefault(rightcol, sql.bindparam(bind_label(), None, shortname=rightcol.name, type=binary.left.type)) @@ -299,13 +333,9 @@ class LazyLoader(AbstractRelationLoader): if secondaryjoin is not None: secondaryjoin = secondaryjoin.copy_container() - if self.loads_polymorphic: - secondaryjoin.accept_visitor(sql_util.ClauseAdapter(select_table)) lazywhere = sql.and_(lazywhere, secondaryjoin) - else: - if self.loads_polymorphic: - lazywhere.accept_visitor(sql_util.ClauseAdapter(select_table)) - + + #print "LAZYCLAUSE", str(lazywhere) LazyLoader.logger.info("create_lazy_clause " + str(lazywhere)) return (lazywhere, binds, reverse) @@ -317,7 +347,7 @@ class EagerLoader(AbstractRelationLoader): """loads related objects inline with a parent query.""" def init(self): super(EagerLoader, self).init() - if self.parent.isa(self.select_mapper): + if self.parent.isa(self.mapper): raise exceptions.ArgumentError("Error creating eager relationship '%s' on parent class '%s' to child class '%s': Cant use eager loading on a self referential relationship." % (self.key, repr(self.parent.class_), repr(self.mapper.class_))) self.parent._eager_loaders.add(self.parent_property) @@ -364,16 +394,12 @@ class EagerLoader(AbstractRelationLoader): eagerloader.target:self.eagertarget, eagerloader.secondary:self.eagersecondary }) - self.eagersecondaryjoin = eagerloader.secondaryjoin.copy_container() - if eagerloader.loads_polymorphic: - self.eagersecondaryjoin.accept_visitor(sql_util.ClauseAdapter(eagerloader.select_table)) + self.eagersecondaryjoin = eagerloader.polymorphic_secondaryjoin.copy_container() self.eagersecondaryjoin.accept_visitor(self.aliasizer) - self.eagerprimary = eagerloader.primaryjoin.copy_container() + self.eagerprimary = eagerloader.polymorphic_primaryjoin.copy_container() self.eagerprimary.accept_visitor(self.aliasizer) else: - self.eagerprimary = eagerloader.primaryjoin.copy_container() - if eagerloader.loads_polymorphic: - self.eagerprimary.accept_visitor(sql_util.ClauseAdapter(eagerloader.select_table)) + self.eagerprimary = eagerloader.polymorphic_primaryjoin.copy_container() self.aliasizer = sql_util.Aliasizer(self.target, aliases={self.target:self.eagertarget}) self.eagerprimary.accept_visitor(self.aliasizer) diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py index 8b36ba0264..2be808ba61 100644 --- a/lib/sqlalchemy/schema.py +++ b/lib/sqlalchemy/schema.py @@ -65,8 +65,7 @@ class SchemaItem(object): a local non-None value overrides all others. after that, the parent item (i.e. Column for a Sequence, Table for a Column, MetaData for a Table) is searched for a non-None setting, traversing each parent until none are found. - finally, case_sensitive is set to True if and only if the name of this item - is not all lowercase. + finally, case_sensitive is set to True as a default. """ local = getattr(self, '_%s_setting' % keyname, None) if local is not None: @@ -78,7 +77,7 @@ class SchemaItem(object): parentval = getattr(parent, '_case_sensitive_setting', None) if parentval is not None: return parentval - return name is not None and name.lower() != name + return True def _get_case_sensitive(self): try: return self.__case_sensitive @@ -194,11 +193,9 @@ class Table(SchemaItem, sql.TableClause): quote_schema=False : indicates that the Namespace identifier must be properly escaped and quoted before being sent to the database. This flag overrides all other quoting behavior. - case_sensitive=True : indicates that the identifier should be interpreted by the database in the natural case for identifiers. - Mixed case is not sufficient to cause this identifier to be quoted; it must contain an illegal character. + case_sensitive=True : indicates quoting should be used if the identifier needs it. - case_sensitive_schema=True : indicates that the identifier should be interpreted by the database in the natural case for identifiers. - Mixed case is not sufficient to cause this identifier to be quoted; it must contain an illegal character. + case_sensitive_schema=True : indicates quoting should be used if the identifier needs it. """ super(Table, self).__init__(name) self._metadata = metadata @@ -365,8 +362,7 @@ class Column(SchemaItem, sql._ColumnClause): to the database. This flag should normally not be required as dialects can auto-detect conditions where quoting is required. - case_sensitive=True : indicates that the identifier should be interpreted by the database in the natural case for identifiers. - Mixed case is not sufficient to cause this identifier to be quoted; it must contain an illegal character. + case_sensitive=True : indicates quoting should be used if the identifier needs it. """ name = str(name) # in case of incoming unicode super(Column, self).__init__(name, None, type) diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py index 5598472615..5f392a61cb 100644 --- a/lib/sqlalchemy/sql.py +++ b/lib/sqlalchemy/sql.py @@ -694,7 +694,12 @@ class ColumnCollection(util.OrderedProperties): if c.shares_lineage(local): l.append(c==local) return and_(*l) - + def contains_column(self, col): + # have to use a Set here, because it will compare the identity + # of the column, not just using "==" for comparison which will always return a + # "True" value (i.e. a BinaryClause...) + return col in util.Set(self) + class FromClause(Selectable): """represents an element that can be used within the FROM clause of a SELECT statement.""" def __init__(self, name=None): @@ -1400,6 +1405,7 @@ class CompoundSelect(_SelectBaseMixin, FromClause): for c in s.c: yield c def _proxy_column(self, column): + print "PROXYING COLUMN", type(column), column if self.use_labels: col = column._make_proxy(self, name=column._label) else: diff --git a/lib/sqlalchemy/sql_util.py b/lib/sqlalchemy/sql_util.py index 6b87a2dec3..10d4495d93 100644 --- a/lib/sqlalchemy/sql_util.py +++ b/lib/sqlalchemy/sql_util.py @@ -135,17 +135,31 @@ class ClauseAdapter(sql.ClauseVisitor): s.c.col1 == table2.c.col1 """ - def __init__(self, selectable): + def __init__(self, selectable, include=None, exclude=None, equivalents=None): self.selectable = selectable + self.include = include + self.exclude = exclude + self.equivalents = equivalents + def include_col(self, col): + if not isinstance(col, sql.ColumnElement): + return None + if self.include is not None: + if col not in self.include: + return None + if self.exclude is not None: + if col in self.exclude: + return None + newcol = self.selectable.corresponding_column(col, raiseerr=False, keys_ok=False) + if newcol is None and self.equivalents is not None and col in self.equivalents: + newcol = self.selectable.corresponding_column(self.equivalents[col], raiseerr=False, keys_ok=False) + return newcol def visit_binary(self, binary): - if isinstance(binary.left, sql.ColumnElement): - col = self.selectable.corresponding_column(binary.left, raiseerr=False, keys_ok=True) - if col is not None: - binary.left = col - if isinstance(binary.right, sql.ColumnElement): - col = self.selectable.corresponding_column(binary.right, raiseerr=False, keys_ok=True) - if col is not None: - binary.right = col + col = self.include_col(binary.left) + if col is not None: + binary.left = col + col = self.include_col(binary.right) + if col is not None: + binary.right = col class ColumnsInClause(sql.ClauseVisitor): """given a selectable, visits clauses and determines if any columns from the clause are in the selectable""" diff --git a/test/orm/inheritance5.py b/test/orm/inheritance5.py index ab948c0035..49eca2fc3c 100644 --- a/test/orm/inheritance5.py +++ b/test/orm/inheritance5.py @@ -1,5 +1,6 @@ from sqlalchemy import * import testbase +from sqlalchemy.ext.selectresults import SelectResults class AttrSettable(object): def __init__(self, **kwargs): @@ -25,7 +26,11 @@ class RelationTest1(testbase.ORMTest): Column('manager_name', String(50)) ) - def testbasic(self): + def tearDown(self): + people.update(values={people.c.manager_id:None}).execute() + super(RelationTest1, self).tearDown() + + def testparentrefsdescendant(self): class Person(AttrSettable): pass class Manager(Person): @@ -59,11 +64,35 @@ class RelationTest1(testbase.ORMTest): m = session.query(Manager).get(m.person_id) print p, m, p.manager assert p.manager is m + + def testdescendantrefsparent(self): + class Person(AttrSettable): + pass + class Manager(Person): + pass + + mapper(Person, people) + mapper(Manager, managers, inherits=Person, inherit_condition=people.c.person_id==managers.c.person_id, properties={ + 'employee':relation(Person, primaryjoin=people.c.manager_id==managers.c.person_id, foreignkey=people.c.manager_id, uselist=False, post_update=True) + }) + + session = create_session() + p = Person(name='some person') + m = Manager(name='some manager') + m.employee = p + session.save(m) + session.flush() + session.clear() + + p = session.query(Person).get(p.person_id) + m = session.query(Manager).get(m.person_id) + print p, m, m.employee + assert m.employee is p class RelationTest2(testbase.ORMTest): """test self-referential relationships on polymorphic mappers""" def define_tables(self, metadata): - global people, managers + global people, managers, data people = Table('people', metadata, Column('person_id', Integer, Sequence('person_id_seq', optional=True), primary_key=True), Column('name', String(50)), @@ -74,28 +103,66 @@ class RelationTest2(testbase.ORMTest): Column('manager_id', Integer, ForeignKey('people.person_id')), Column('status', String(30)), ) - - def testrelationonsubclass(self): + + data = Table('data', metadata, + Column('person_id', Integer, ForeignKey('managers.person_id'), primary_key=True), + Column('data', String(30)) + ) + + def testrelationonsubclass_j1_nodata(self): + self.do_test("join1", False) + def testrelationonsubclass_j2_nodata(self): + self.do_test("join2", False) + def testrelationonsubclass_j1_data(self): + self.do_test("join1", True) + def testrelationonsubclass_j2_data(self): + self.do_test("join2", True) + + def do_test(self, jointype="join1", usedata=False): class Person(AttrSettable): pass class Manager(Person): pass - poly_union = polymorphic_union({ - 'person':people.select(people.c.type=='person'), - 'manager':managers.join(people, people.c.person_id==managers.c.person_id) - }, None) + if jointype == "join1": + poly_union = polymorphic_union({ + 'person':people.select(people.c.type=='person'), + 'manager':join(people, managers, people.c.person_id==managers.c.person_id) + }, None) + elif jointype == "join2": + poly_union = polymorphic_union({ + 'person':people.select(people.c.type=='person'), + 'manager':managers.join(people, people.c.person_id==managers.c.person_id) + }, None) + + if usedata: + class Data(object): + def __init__(self, data): + self.data = data + mapper(Data, data) + + mapper(Person, people, select_table=poly_union, polymorphic_identity='person', polymorphic_on=poly_union.c.type) + + if usedata: + mapper(Manager, managers, inherits=Person, inherit_condition=people.c.person_id==managers.c.person_id, polymorphic_identity='manager', + properties={ + 'colleague':relation(Person, primaryjoin=managers.c.manager_id==people.c.person_id, lazy=True, uselist=False), + 'data':relation(Data, uselist=False) + } + ) + else: + mapper(Manager, managers, inherits=Person, inherit_condition=people.c.person_id==managers.c.person_id, polymorphic_identity='manager', + properties={ + 'colleague':relation(Person, primaryjoin=managers.c.manager_id==people.c.person_id, lazy=True, uselist=False) + } + ) - mapper(Person, people, select_table=poly_union, polymorphic_identity='person', polymorphic_on=people.c.type) - mapper(Manager, managers, inherits=Person, inherit_condition=people.c.person_id==managers.c.person_id, polymorphic_identity='manager', - properties={ - 'colleague':relation(Person, primaryjoin=managers.c.manager_id==people.c.person_id, uselist=False) - }) - class_mapper(Person).compile() sess = create_session() p = Person(name='person1') m = Manager(name='manager1') m.colleague = p + if usedata: + m.data = Data('ms data') sess.save(m) sess.flush() @@ -105,11 +172,13 @@ class RelationTest2(testbase.ORMTest): print p print m assert m.colleague is p + if usedata: + assert m.data.data == 'ms data' class RelationTest3(testbase.ORMTest): """test self-referential relationships on polymorphic mappers""" def define_tables(self, metadata): - global people, managers + global people, managers, data people = Table('people', metadata, Column('person_id', Integer, Sequence('person_id_seq', optional=True), primary_key=True), Column('colleague_id', Integer, ForeignKey('people.person_id')), @@ -121,24 +190,60 @@ class RelationTest3(testbase.ORMTest): Column('status', String(30)), ) - def testrelationonbaseclass(self): + data = Table('data', metadata, + Column('person_id', Integer, ForeignKey('people.person_id'), primary_key=True), + Column('data', String(30)) + ) + + def testrelationonbaseclass_j1_nodata(self): + self.do_test("join1", False) + def testrelationonbaseclass_j2_nodata(self): + self.do_test("join2", False) + def testrelationonbaseclass_j1_data(self): + self.do_test("join1", True) + def testrelationonbaseclass_j2_data(self): + self.do_test("join2", True) + + def do_test(self, jointype="join1", usedata=False): class Person(AttrSettable): pass class Manager(Person): pass - poly_union = polymorphic_union({ - 'manager':managers.join(people, people.c.person_id==managers.c.person_id), - 'person':people.select(people.c.type=='person') - }, None) - - mapper(Person, people, select_table=poly_union, polymorphic_identity='person', polymorphic_on=people.c.type, - properties={ - 'colleagues':relation(Person, primaryjoin=people.c.colleague_id==people.c.person_id, - remote_side=people.c.person_id, uselist=True) - } - ) + if usedata: + class Data(object): + def __init__(self, data): + self.data = data + + if jointype == "join1": + poly_union = polymorphic_union({ + 'manager':managers.join(people, people.c.person_id==managers.c.person_id), + 'person':people.select(people.c.type=='person') + }, None) + elif jointype =="join2": + poly_union = polymorphic_union({ + 'manager':join(people, managers, people.c.person_id==managers.c.person_id), + 'person':people.select(people.c.type=='person') + }, None) + + if usedata: + mapper(Data, data) + mapper(Manager, managers, inherits=Person, inherit_condition=people.c.person_id==managers.c.person_id, polymorphic_identity='manager') + if usedata: + mapper(Person, people, select_table=poly_union, polymorphic_identity='person', polymorphic_on=people.c.type, + properties={ + 'colleagues':relation(Person, primaryjoin=people.c.colleague_id==people.c.person_id, remote_side=people.c.colleague_id, uselist=True), + 'data':relation(Data, uselist=False) + } + ) + else: + mapper(Person, people, select_table=poly_union, polymorphic_identity='person', polymorphic_on=people.c.type, + properties={ + 'colleagues':relation(Person, primaryjoin=people.c.colleague_id==people.c.person_id, + remote_side=people.c.colleague_id, uselist=True) + } + ) sess = create_session() p = Person(name='person1') @@ -146,6 +251,10 @@ class RelationTest3(testbase.ORMTest): m = Manager(name='manager1') p.colleagues.append(p2) m.colleagues.append(p2) + if usedata: + p.data = Data('ps data') + m.data = Data('ms data') + sess.save(m) sess.save(p) sess.flush() @@ -156,7 +265,11 @@ class RelationTest3(testbase.ORMTest): print p, p2, p.colleagues assert len(p.colleagues) == 1 assert p.colleagues == [p2] + if usedata: + assert p.data.data == 'ps data' + assert m.data.data == 'ms data' + class RelationTest4(testbase.ORMTest): def define_tables(self, metadata): global people, engineers, managers, cars @@ -257,6 +370,11 @@ class RelationTest4(testbase.ORMTest): car1 = session.query(Car).options(eagerload('employee')).get(car1.car_id) assert str(car1.employee) == "Engineer E4, status X" + session.clear() + s = SelectResults(session.query(Car)) + c = s.join_to("employee").select(employee_join.c.name=="E4")[0] + assert c.car_id==car1.car_id + class RelationTest5(testbase.ORMTest): def define_tables(self, metadata): global people, engineers, managers, cars @@ -317,6 +435,125 @@ class RelationTest5(testbase.ORMTest): assert carlist[0].manager is None assert carlist[1].manager.person_id == car2.manager.person_id +class SelectResultsTest(testbase.AssertMixin): + def setUpAll(self): + # cars---owned by--- people (abstract) --- has a --- status + # | ^ ^ | + # | | | | + # | engineers managers | + # | | + # +--------------------------------------- has a ------+ + + global metadata, status, people, engineers, managers, cars + metadata = BoundMetaData(testbase.db) + # table definitions + status = Table('status', metadata, + Column('status_id', Integer, primary_key=True), + Column('name', String(20))) + + people = Table('people', metadata, + Column('person_id', Integer, primary_key=True), + Column('status_id', Integer, ForeignKey('status.status_id'), nullable=False), + Column('name', String(50))) + + engineers = Table('engineers', metadata, + Column('person_id', Integer, ForeignKey('people.person_id'), primary_key=True), + Column('field', String(30))) + + managers = Table('managers', metadata, + Column('person_id', Integer, ForeignKey('people.person_id'), primary_key=True), + Column('category', String(70))) + + cars = Table('cars', metadata, + Column('car_id', Integer, primary_key=True), + Column('status_id', Integer, ForeignKey('status.status_id'), nullable=False), + Column('owner', Integer, ForeignKey('people.person_id'), nullable=False)) + + metadata.create_all() + + def tearDownAll(self): + metadata.drop_all() + def tearDown(self): + clear_mappers() + for t in metadata.table_iterator(reverse=True): + t.delete().execute() + + def testjointo(self): + # class definitions + class PersistentObject(object): + def __init__(self, **kwargs): + for key, value in kwargs.iteritems(): + setattr(self, key, value) + class Status(PersistentObject): + def __repr__(self): + return "Status %s" % self.name + class Person(PersistentObject): + def __repr__(self): + return "Ordinary person %s" % self.name + class Engineer(Person): + def __repr__(self): + return "Engineer %s, field %s, status %s" % (self.name, self.field, self.status) + class Manager(Person): + def __repr__(self): + return "Manager %s, category %s, status %s" % (self.name, self.category, self.status) + class Car(PersistentObject): + def __repr__(self): + return "Car number %d" % self.car_id + + # create a union that represents both types of joins. + employee_join = polymorphic_union( + { + 'engineer':people.join(engineers), + 'manager':people.join(managers), + }, "type", 'employee_join') + + status_mapper = mapper(Status, status) + person_mapper = mapper(Person, people, + select_table=employee_join,polymorphic_on=employee_join.c.type, + polymorphic_identity='person', properties={'status':relation(status_mapper)}) + engineer_mapper = mapper(Engineer, engineers, inherits=person_mapper, polymorphic_identity='engineer') + manager_mapper = mapper(Manager, managers, inherits=person_mapper, polymorphic_identity='manager') + car_mapper = mapper(Car, cars, properties= {'employee':relation(person_mapper), 'status':relation(status_mapper)}) + + session = create_session(echo_uow=False) + + active = Status(name="active") + dead = Status(name="dead") + + session.save(active) + session.save(dead) + session.flush() + + # creating 5 managers named from M1 to M5 and 5 engineers named from E1 to E5 + # M4, M5, E4 and E5 are dead + for i in range(1,5): + if i<4: + st=active + else: + st=dead + session.save(Manager(name="M%d" % i,category="YYYYYYYYY",status=st)) + session.save(Engineer(name="E%d" % i,field="X",status=st)) + + session.flush() + + # get E4 + engineer4 = session.query(engineer_mapper).get_by(name="E4") + + # create 2 cars for E4, one active and one dead + car1 = Car(employee=engineer4,status=active) + car2 = Car(employee=engineer4,status=dead) + session.save(car1) + session.save(car2) + session.flush() + +# for activeCars in SelectResults(session.query(Car)).join_to('status').select(status.c.name=="active"): +# print activeCars + for activePerson in SelectResults(session.query(Person)).join_to('status').select(status.c.name=="active"): + print activePerson +# for activePerson in SelectResults(session.query(Person)).join_to('status').select_by(name="active"): +# print activePerson + + class MultiLevelTest(testbase.ORMTest): def define_tables(self, metadata): global table_Employee, table_Engineer, table_Manager @@ -344,11 +581,18 @@ class MultiLevelTest(testbase.ORMTest): __repr__ = __str__ class Engineer( Employee): pass class Manager( Engineer): pass + pu_Employee = polymorphic_union( { 'Manager': table_Employee.join( table_Engineer).join( table_Manager), 'Engineer': select([table_Employee, table_Engineer.c.machine], table_Employee.c.atype == 'Engineer', from_obj=[table_Employee.join(table_Engineer)]), 'Employee': table_Employee.select( table_Employee.c.atype == 'Employee'), }, None, 'pu_employee', ) + +# pu_Employee = polymorphic_union( { +# 'Manager': table_Employee.join( table_Engineer).join( table_Manager), +# 'Engineer': table_Employee.join(table_Engineer).select(table_Employee.c.atype == 'Engineer'), +# 'Employee': table_Employee.select( table_Employee.c.atype == 'Employee'), +# }, None, 'pu_employee', ) mapper_Employee = mapper( Employee, table_Employee, polymorphic_identity= 'Employee', @@ -389,4 +633,4 @@ class MultiLevelTest(testbase.ORMTest): if __name__ == "__main__": testbase.main() - \ No newline at end of file + diff --git a/test/orm/poly_linked_list.py b/test/orm/poly_linked_list.py index adf844c5c6..2f4ee96ff0 100644 --- a/test/orm/poly_linked_list.py +++ b/test/orm/poly_linked_list.py @@ -1,12 +1,10 @@ import testbase from sqlalchemy import * -class PolymorphicCircularTest(testbase.PersistTest): - def setUpAll(self): - global metadata +class PolymorphicCircularTest(testbase.ORMTest): + keep_mappers = True + def define_tables(self, metadata): global Table1, Table1B, Table2, Table3, Data - metadata = BoundMetaData(testbase.db) - table1 = Table('table1', metadata, Column('id', Integer, primary_key=True), Column('related_id', Integer, ForeignKey('table1.id'), nullable=True), @@ -28,8 +26,6 @@ class PolymorphicCircularTest(testbase.PersistTest): Column('data', String(30)) ) - metadata.create_all() - join = polymorphic_union( { 'table3' : table1.join(table3), @@ -61,7 +57,7 @@ class PolymorphicCircularTest(testbase.PersistTest): self.data = data def __repr__(self): return "%s(%d, %s)" % (self.__class__.__name__, self.id, repr(str(self.data))) - + try: # this is how the mapping used to work. insure that this raises an error now table1_mapper = mapper(Table1, table1, @@ -71,8 +67,8 @@ class PolymorphicCircularTest(testbase.PersistTest): properties={ 'next': relation(Table1, backref=backref('prev', primaryjoin=join.c.id==join.c.related_id, foreignkey=join.c.id, uselist=False), - uselist=False, lazy=False, primaryjoin=join.c.id==join.c.related_id), - 'data':relation(mapper(Data, data), lazy=False) + uselist=False, primaryjoin=join.c.id==join.c.related_id), + 'data':relation(mapper(Data, data), lazy=lazy) } ) table1_mapper.compile() @@ -81,8 +77,11 @@ class PolymorphicCircularTest(testbase.PersistTest): assert True clear_mappers() - # currently, all of these "eager" relationships degrade to lazy relationships + # currently, the "eager" relationships degrade to lazy relationships # due to the polymorphic load. + # the "next" relation used to have a "lazy=False" on it, but the EagerLoader raises the "self-referential" + # exception now. since eager loading would never work for that relation anyway, its better that the user + # gets an exception instead of it silently not eager loading. table1_mapper = mapper(Table1, table1, select_table=join, polymorphic_on=join.c.type, @@ -90,12 +89,10 @@ class PolymorphicCircularTest(testbase.PersistTest): properties={ 'next': relation(Table1, backref=backref('prev', primaryjoin=table1.c.id==table1.c.related_id, remote_side=table1.c.id, uselist=False), - uselist=False, lazy=False, primaryjoin=table1.c.id==table1.c.related_id), + uselist=False, primaryjoin=table1.c.id==table1.c.related_id), 'data':relation(mapper(Data, data), lazy=False) } ) - - table1b_mapper = mapper(Table1B, inherits=table1_mapper, polymorphic_identity='table1b') @@ -104,13 +101,6 @@ class PolymorphicCircularTest(testbase.PersistTest): polymorphic_identity='table2') table3_mapper = mapper(Table3, table3, inherits=table1_mapper, polymorphic_identity='table3') - def tearDown(self): - for t in metadata.table_iterator(reverse=True): - t.delete().execute() - - def tearDownAll(self): - clear_mappers() - metadata.drop_all() def testone(self): self.do_testlist([Table1, Table2, Table1, Table2]) @@ -188,9 +178,9 @@ class PolymorphicCircularTest(testbase.PersistTest): backwards = repr(assertlist) # everything should match ! - print original - print backwards - print forwards + print "ORIGNAL", original + print "BACKWARDS",backwards + print "FORWARDS", forwards assert original == forwards == backwards if __name__ == '__main__': diff --git a/test/orm/polymorph.py b/test/orm/polymorph.py index 2c66fc7fc5..b0598e9af1 100644 --- a/test/orm/polymorph.py +++ b/test/orm/polymorph.py @@ -234,10 +234,24 @@ class MultipleTableTest(testbase.PersistTest): print "\n" - dilbert = session.query(Person).selectfirst(person_join.c.name=='dilbert') + # test selecting from the query, using the base mapped table (people) as the selection criterion. + # in the case of the polymorphic Person query, the "people" selectable should be adapted to be "person_join" + dilbert = session.query(Person).selectfirst(people.c.name=='dilbert') dilbert2 = session.query(Engineer).selectfirst(people.c.name=='dilbert') assert dilbert is dilbert2 + # test selecting from the query, joining against an alias of the base "people" table. test that + # the "palias" alias does *not* get sucked up into the "person_join" conversion. + palias = people.alias("palias") + session.query(Person).selectfirst((palias.c.name=='dilbert') & (palias.c.person_id==people.c.person_id)) + dilbert2 = session.query(Engineer).selectfirst((palias.c.name=='dilbert') & (palias.c.person_id==people.c.person_id)) + assert dilbert is dilbert2 + + session.query(Person).selectfirst((engineers.c.engineer_name=="engineer1") & (engineers.c.person_id==people.c.person_id)) + dilbert2 = session.query(Engineer).selectfirst(engineers.c.engineer_name=="engineer1") + assert dilbert is dilbert2 + + dilbert.engineer_name = 'hes dibert!' session.flush() diff --git a/test/orm/relationships.py b/test/orm/relationships.py index 250f5ab0d5..d7fff287f8 100644 --- a/test/orm/relationships.py +++ b/test/orm/relationships.py @@ -173,7 +173,7 @@ class RelationTest2(testbase.PersistTest): assert sess.query(Employee).get([c1.company_id, 3]).reports_to.name == 'emp1' assert sess.query(Employee).get([c2.company_id, 3]).reports_to.name == 'emp5' - def testimplict(self): + def testimplicit(self): """test with mappers that have the most minimal arguments""" class Company(object): pass diff --git a/test/sql/quote.py b/test/sql/quote.py index 6438586f6b..8ae228031f 100644 --- a/test/sql/quote.py +++ b/test/sql/quote.py @@ -89,6 +89,13 @@ class QuoteTest(PersistTest): where the "UPPERCASE" column of "LaLa" doesnt exist. """ x = table1.select(distinct=True).alias("LaLa").select().scalar() + + def testlabels2(self): + metadata = MetaData() + table = Table("ImATable", metadata, + Column("col1", Integer)) + x = select([table.c.col1.label("ImATable_col1")]).alias("SomeAlias") + assert str(select([x.c.ImATable_col1])) == '''SELECT "SomeAlias"."ImATable_col1" \nFROM (SELECT "ImATable".col1 AS "ImATable_col1" \nFROM "ImATable") AS "SomeAlias"''' def testlabelsnocase(self): metadata = MetaData() diff --git a/test/testbase.py b/test/testbase.py index 7d981509d1..bc5153af0e 100644 --- a/test/testbase.py +++ b/test/testbase.py @@ -205,6 +205,8 @@ class AssertMixin(PersistTest): self.assert_(db.sql_count == count, "desired statement count %d does not match %d" % (count, db.sql_count)) class ORMTest(AssertMixin): + keep_mappers = False + keep_data = False def setUpAll(self): global metadata metadata = BoundMetaData(db) @@ -217,10 +219,11 @@ class ORMTest(AssertMixin): def tearDownAll(self): metadata.drop_all() def tearDown(self): - clear_mappers() - for t in metadata.table_iterator(reverse=True): - t.delete().execute().close() - + if not self.keep_mappers: + clear_mappers() + if not self.keep_data: + for t in metadata.table_iterator(reverse=True): + t.delete().execute().close() class EngineAssert(proxy.BaseProxyEngine): """decorates a SQLEngine object to match the incoming queries against a set of assertions."""