From: Mike Bayer Date: Mon, 3 Nov 2008 02:52:30 +0000 (+0000) Subject: - Improved the behavior of aliased() objects such that they more X-Git-Tag: rel_0_5rc3~16 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=a5dfbeedb9f7ae148081d1dbc3e91e876526eb90;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - Improved the behavior of aliased() objects such that they more accurately adapt the expressions generated, which helps particularly with self-referential comparisons. [ticket:1171] - Fixed bug involving primaryjoin/secondaryjoin conditions constructed from class-bound attributes (as often occurs when using declarative), which later would be inappropriately aliased by Query, particularly with the various EXISTS based comparators. --- diff --git a/CHANGES b/CHANGES index 54409f31b8..f463ea55e3 100644 --- a/CHANGES +++ b/CHANGES @@ -15,6 +15,16 @@ CHANGES to an iterable. Use contains() to test for collection membership. + - Improved the behavior of aliased() objects such that they more + accurately adapt the expressions generated, which helps + particularly with self-referential comparisons. [ticket:1171] + + - Fixed bug involving primaryjoin/secondaryjoin conditions + constructed from class-bound attributes (as often occurs + when using declarative), which later would be inappropriately + aliased by Query, particularly with the various EXISTS + based comparators. + - Improved weakref identity map memory management to no longer require mutexing, resurrects garbage collected instance on a lazy basis for an InstanceState with pending changes. diff --git a/examples/nested_sets/nested_sets.py b/examples/nested_sets/nested_sets.py index 8af9aed455..e83a263e93 100644 --- a/examples/nested_sets/nested_sets.py +++ b/examples/nested_sets/nested_sets.py @@ -84,10 +84,9 @@ session.commit() print session.query(Employee).all() # 1. Find an employee and all his/her supervisors, no matter how deep the tree. -# (the between() operator in SQLAlchemy has a bug here, [ticket:1171]) ealias = aliased(Employee) print session.query(Employee).\ - filter(ealias.left>=Employee.left).filter(ealias.left<=Employee.right).\ + filter(ealias.left.between(Employee.left, Employee.right)).\ filter(ealias.emp=='Eddie').all() #2. Find the employee and all his/her subordinates. (This query has a nice symmetry with the first query.) @@ -97,7 +96,7 @@ print session.query(Employee).\ #3. Find the level of each node, so you can print the tree as an indented listing. for indentation, employee in session.query(func.count(Employee.emp).label('indentation') - 1, ealias).\ - filter(ealias.left>=Employee.left).filter(ealias.left<=Employee.right).\ + filter(ealias.left.between(Employee.left, Employee.right)).\ group_by(ealias.emp).\ order_by(ealias.left): print " " * indentation + str(employee) diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py index 631d3f5820..bd934ce13b 100644 --- a/lib/sqlalchemy/orm/interfaces.py +++ b/lib/sqlalchemy/orm/interfaces.py @@ -438,9 +438,21 @@ class PropComparator(expression.ColumnOperators): PropComparator. """ + def __init__(self, prop, mapper, adapter=None): + self.prop = self.property = prop + self.mapper = mapper + self.adapter = adapter + def __clause_element__(self): raise NotImplementedError("%r" % self) + def adapted(self, adapter): + """Return a copy of this PropComparator which will use the given adaption function + on the local side of generated expressions. + + """ + return self.__class__(self.prop, self.mapper, adapter) + @staticmethod def any_op(a, b, **kwargs): return a.any(b, **kwargs) @@ -449,10 +461,6 @@ class PropComparator(expression.ColumnOperators): def has_op(a, b, **kwargs): return a.has(b, **kwargs) - def __init__(self, prop, mapper): - self.prop = self.property = prop - self.mapper = mapper - @staticmethod def of_type_op(a, class_): return a.of_type(class_) diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index 87e35eb831..2b860af370 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -18,7 +18,7 @@ from sqlalchemy.sql import operators, expression from sqlalchemy.orm import ( attributes, dependency, mapper, object_mapper, strategies, ) -from sqlalchemy.orm.util import CascadeOptions, _class_to_mapper, _orm_annotate +from sqlalchemy.orm.util import CascadeOptions, _class_to_mapper, _orm_annotate, _orm_deannotate from sqlalchemy.orm.interfaces import ( MANYTOMANY, MANYTOONE, MapperProperty, ONETOMANY, PropComparator, StrategizedProperty, @@ -85,8 +85,11 @@ class ColumnProperty(StrategizedProperty): class ColumnComparator(PropComparator): @util.memoized_instancemethod def __clause_element__(self): - return self.prop.columns[0]._annotate({"parententity": self.mapper}) - + if self.adapter: + return self.adapter(self.prop.columns[0]) + else: + return self.prop.columns[0]._annotate({"parententity": self.mapper}) + def operate(self, op, *other, **kwargs): return op(self.__clause_element__(), *other, **kwargs) @@ -147,7 +150,11 @@ class CompositeProperty(ColumnProperty): class Comparator(PropComparator): def __clause_element__(self): - return expression.ClauseList(*self.prop.columns) + if self.adapter: + # TODO: test coverage for adapted composite comparison + return expression.ClauseList(*[self.adapter(x) for x in self.prop.columns]) + else: + return expression.ClauseList(*self.prop.columns) def __eq__(self, other): if other is None: @@ -318,18 +325,30 @@ class PropertyLoader(StrategizedProperty): self._is_backref = _is_backref class Comparator(PropComparator): - def __init__(self, prop, mapper, of_type=None): + def __init__(self, prop, mapper, of_type=None, adapter=None): self.prop = self.property = prop self.mapper = mapper + self.adapter = adapter if of_type: self._of_type = _class_to_mapper(of_type) + def adapted(self, adapter): + """Return a copy of this PropComparator which will use the given adaption function + on the local side of generated expressions. + + """ + return PropertyLoader.Comparator(self.prop, self.mapper, getattr(self, '_of_type', None), adapter) + @property def parententity(self): return self.prop.parent def __clause_element__(self): - return self.prop.parent._with_polymorphic_selectable + elem = self.prop.parent._with_polymorphic_selectable + if self.adapter: + return self.adapter(elem) + else: + return elem def operate(self, op, *other, **kwargs): return op(self, *other, **kwargs) @@ -343,13 +362,13 @@ class PropertyLoader(StrategizedProperty): def __eq__(self, other): if other is None: if self.prop.direction in [ONETOMANY, MANYTOMANY]: - return ~sql.exists([1], self.prop.primaryjoin) + return ~self._criterion_exists() else: - return self.prop._optimized_compare(None) + return self.prop._optimized_compare(None, adapt_source=self.adapter) elif self.prop.uselist: raise sa_exc.InvalidRequestError("Can't compare a collection to an object or collection; use contains() to test for membership.") else: - return self.prop._optimized_compare(other) + return self.prop._optimized_compare(other, adapt_source=self.adapter) def _criterion_exists(self, criterion=None, **kwargs): if getattr(self, '_of_type', None): @@ -360,7 +379,12 @@ class PropertyLoader(StrategizedProperty): else: to_selectable = None - pj, sj, source, dest, secondary, target_adapter = self.prop._create_joins(dest_polymorphic=True, dest_selectable=to_selectable) + if self.adapter: + source_selectable = self.__clause_element__() + else: + source_selectable = None + pj, sj, source, dest, secondary, target_adapter = \ + self.prop._create_joins(dest_polymorphic=True, dest_selectable=to_selectable, source_selectable=source_selectable) for k in kwargs: crit = self.prop.mapper.class_manager.get_inst(k) == kwargs[k] @@ -368,12 +392,15 @@ class PropertyLoader(StrategizedProperty): criterion = crit else: criterion = criterion & crit - + + # annotate the *local* side of the join condition, in the case of pj + sj this + # is the full primaryjoin, in the case of just pj its the local side of + # the primaryjoin. if sj: j = _orm_annotate(pj) & sj else: j = _orm_annotate(pj, exclude=self.prop.remote_side) - + if criterion and target_adapter: # limit this adapter to annotated only? criterion = target_adapter.traverse(criterion) @@ -383,7 +410,10 @@ class PropertyLoader(StrategizedProperty): # to anything in the enclosing query. if criterion: criterion = criterion._annotate({'_halt_adapt': True}) - return sql.exists([1], j & criterion, from_obj=dest).correlate(source) + + crit = j & criterion + + return sql.exists([1], crit, from_obj=dest).correlate(source) def any(self, criterion=None, **kwargs): if not self.prop.uselist: @@ -399,7 +429,7 @@ class PropertyLoader(StrategizedProperty): def contains(self, other, **kwargs): if not self.prop.uselist: raise sa_exc.InvalidRequestError("'contains' not implemented for scalar attributes. Use ==") - clause = self.prop._optimized_compare(other) + clause = self.prop._optimized_compare(other, adapt_source=self.adapter) if self.prop.secondaryjoin: clause.negation_clause = self.__negated_contains_or_equals(other) @@ -410,12 +440,22 @@ class PropertyLoader(StrategizedProperty): if self.prop.direction == MANYTOONE: state = attributes.instance_state(other) strategy = self.prop._get_strategy(strategies.LazyLoader) + + def state_bindparam(state, col): + o = state.obj() # strong ref + return lambda: self.prop.mapper._get_committed_attr_by_column(o, col) + + def adapt(col): + if self.adapter: + return self.adapter(col) + else: + return col + if strategy.use_get: return sql.and_(*[ sql.or_( - x != - self.prop.mapper._get_committed_state_attr_by_column(state, y), - x == None) + adapt(x) != state_bindparam(state, y), + adapt(x) == None) for (x, y) in self.prop.local_remote_pairs]) criterion = sql.and_(*[x==y for (x, y) in zip(self.prop.mapper.primary_key, self.prop.mapper.primary_key_from_instance(other))]) @@ -444,10 +484,11 @@ class PropertyLoader(StrategizedProperty): else: return op(self.comparator, value) - def _optimized_compare(self, value, value_is_parent=False): + def _optimized_compare(self, value, value_is_parent=False, adapt_source=None): if value is not None: value = attributes.instance_state(value) - return self._get_strategy(strategies.LazyLoader).lazy_clause(value, reverse_direction=not value_is_parent, alias_secondary=True) + return self._get_strategy(strategies.LazyLoader).\ + lazy_clause(value, reverse_direction=not value_is_parent, alias_secondary=True, adapt_source=adapt_source) def __str__(self): return str(self.parent.class_.__name__) + "." + self.key @@ -549,6 +590,14 @@ class PropertyLoader(StrategizedProperty): for attr in ('order_by', 'primaryjoin', 'secondaryjoin', 'secondary', '_foreign_keys', 'remote_side'): if callable(getattr(self, attr)): setattr(self, attr, getattr(self, attr)()) + + # in the case that InstrumentedAttributes were used to construct + # primaryjoin or secondaryjoin, remove the "_orm_adapt" annotation so these + # interact with Query in the same way as the original Table-bound Column objects + for attr in ('primaryjoin', 'secondaryjoin'): + val = getattr(self, attr) + if val: + setattr(self, attr, _orm_deannotate(val)) if self.order_by: self.order_by = [expression._literal_as_column(x) for x in util.to_list(self.order_by)] diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index 1962a7e2d9..ba5541944c 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -353,9 +353,9 @@ class LazyLoader(AbstractRelationLoader): self.is_class_level = True self._register_attribute(self.parent.class_, callable_=self.class_level_loader) - def lazy_clause(self, state, reverse_direction=False, alias_secondary=False): + def lazy_clause(self, state, reverse_direction=False, alias_secondary=False, adapt_source=None): if state is None: - return self._lazy_none_clause(reverse_direction) + return self._lazy_none_clause(reverse_direction, adapt_source=adapt_source) if not reverse_direction: (criterion, bind_to_col, rev) = (self.__lazywhere, self.__bind_to_col, self._equated_columns) @@ -374,9 +374,12 @@ class LazyLoader(AbstractRelationLoader): if self.parent_property.secondary and alias_secondary: criterion = sql_util.ClauseAdapter(self.parent_property.secondary.alias()).traverse(criterion) - return visitors.cloned_traverse(criterion, {}, {'bindparam':visit_bindparam}) - - def _lazy_none_clause(self, reverse_direction=False): + criterion = visitors.cloned_traverse(criterion, {}, {'bindparam':visit_bindparam}) + if adapt_source: + criterion = adapt_source(criterion) + return criterion + + def _lazy_none_clause(self, reverse_direction=False, adapt_source=None): if not reverse_direction: (criterion, bind_to_col, rev) = (self.__lazywhere, self.__bind_to_col, self._equated_columns) else: @@ -393,7 +396,10 @@ class LazyLoader(AbstractRelationLoader): binary.right = expression.null() binary.operator = operators.is_ - return visitors.cloned_traverse(criterion, {}, {'binary':visit_binary}) + criterion = visitors.cloned_traverse(criterion, {}, {'binary':visit_binary}) + if adapt_source: + criterion = adapt_source(criterion) + return criterion def class_level_loader(self, state, options=None, path=None): if not mapperutil._state_has_identity(state): diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index 264a4d2125..fbc1acd5d1 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -243,6 +243,12 @@ class ExtensionCarrier(dict): return self.get(key, self._pass) class ORMAdapter(sql_util.ColumnAdapter): + """Extends ColumnAdapter to accept ORM entities. + + The selectable is extracted from the given entity, + and the AliasedClass if any is referenced. + + """ def __init__(self, entity, equivalents=None, chain_to=None): mapper, selectable, is_aliased_class = _entity_info(entity) if is_aliased_class: @@ -252,18 +258,36 @@ class ORMAdapter(sql_util.ColumnAdapter): sql_util.ColumnAdapter.__init__(self, selectable, equivalents, chain_to) class AliasedClass(object): + """Represents an 'alias'ed form of a mapped class for usage with Query. + + The ORM equivalent of a sqlalchemy.sql.expression.Alias + object, this object mimics the mapped class using a + __getattr__ scheme and maintains a reference to a + real Alias object. It indicates to Query that the + selectable produced for this class should be aliased, + and also adapts PropComparators produced by the class' + InstrumentedAttributes so that they adapt the + "local" side of SQL expressions against the alias. + + """ def __init__(self, cls, alias=None, name=None): self.__mapper = _class_to_mapper(cls) self.__target = self.__mapper.class_ alias = alias or self.__mapper._with_polymorphic_selectable.alias() self.__adapter = sql_util.ClauseAdapter(alias, equivalents=self.__mapper._equivalent_columns) self.__alias = alias + # used to assign a name to the RowTuple object + # returned by Query. self._sa_label_name = name self.__name__ = 'AliasedClass_' + str(self.__target) + def __adapt_element(self, elem): + return self.__adapter.traverse(elem)._annotate({'parententity': self}) + def __adapt_prop(self, prop): existing = getattr(self.__target, prop.key) - comparator = AliasedComparator(self, self.__adapter, existing.comparator) + comparator = existing.comparator.adapted(self.__adapt_element) + queryattr = attributes.QueryableAttribute( existing.impl, parententity=self, comparator=comparator) setattr(self, prop.key, queryattr) @@ -299,41 +323,16 @@ class AliasedClass(object): return '' % ( id(self), self.__target.__name__) -class AliasedComparator(PropComparator): - def __init__(self, aliasedclass, adapter, comparator): - self.aliasedclass = aliasedclass - self.comparator = comparator - self.adapter = adapter - self.__clause_element = self.adapter.traverse(self.comparator.__clause_element__())._annotate({'parententity': aliasedclass}) - - def __clause_element__(self): - return self.__clause_element - - def operate(self, op, *other, **kwargs): - return self.adapter.traverse(self.comparator.operate(op, *other, **kwargs)) - - def reverse_operate(self, op, other, **kwargs): - return self.adapter.traverse(self.comparator.reverse_operate(op, *other, **kwargs)) - def _orm_annotate(element, exclude=None): """Deep copy the given ClauseElement, annotating each element with the "_orm_adapt" flag. Elements within the exclude collection will be cloned but not annotated. """ - def clone(elem): - if exclude and elem in exclude: - elem = elem._clone() - elif '_orm_adapt' not in elem._annotations: - elem = elem._annotate({'_orm_adapt':True}) - elem._copy_internals(clone=clone) - return elem - - if element is not None: - element = clone(element) - return element - + return sql_util._deep_annotate(element, {'_orm_adapt':True}, exclude) +_orm_deannotate = sql_util._deep_deannotate + class _ORMJoin(expression.Join): """Extend Join to support ORM constructs as input.""" diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index a9ef45dc1e..3b996d6cba 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -987,6 +987,13 @@ class ClauseElement(Visitable): @property def _cloned_set(self): + """Return the set consisting all cloned anscestors of this ClauseElement. + + Includes this ClauseElement. This accessor tends to be used for + FromClause objects to identify 'equivalent' FROM clauses, regardless + of transformative operations. + + """ f = self while f is not None: yield f @@ -1008,7 +1015,11 @@ class ClauseElement(Visitable): if Annotated is None: from sqlalchemy.sql.util import Annotated return Annotated(self, values) - + + def _deannotate(self): + """return a copy of this ClauseElement with an empty annotations dictionary.""" + return self._clone() + def unique_params(self, *optionaldict, **kwargs): """Return a copy with ``bindparam()`` elments replaced. diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index d9c3ed8998..2a510906b1 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -159,6 +159,20 @@ class Annotated(object): clone.__dict__ = self.__dict__.copy() clone._annotations = _values return clone + + def _deannotate(self): + return self.__element + + def _clone(self): + clone = self.__element._clone() + if clone is self.__element: + # detect immutable, don't change anything + return self + else: + # update the clone with any changes that have occured + # to this object's __dict__. + clone.__dict__.update(self.__dict__) + return Annotated(clone, self._annotations) def __hash__(self): return hash(self.__element) @@ -166,6 +180,39 @@ class Annotated(object): def __cmp__(self, other): return cmp(hash(self.__element), hash(other)) +def _deep_annotate(element, annotations, exclude=None): + """Deep copy the given ClauseElement, annotating each element with the given annotations dictionary. + + Elements within the exclude collection will be cloned but not annotated. + + """ + def clone(elem): + # check if element is present in the exclude list. + # take into account proxying relationships. + if exclude and elem.proxy_set.intersection(exclude): + elem = elem._clone() + elif annotations != elem._annotations: + elem = elem._annotate(annotations.copy()) + elem._copy_internals(clone=clone) + return elem + + if element is not None: + element = clone(element) + return element + +def _deep_deannotate(element): + """Deep copy the given element, removing all annotations.""" + + def clone(elem): + elem = elem._deannotate() + elem._copy_internals(clone=clone) + return elem + + if element is not None: + element = clone(element) + return element + + def splice_joins(left, right, stop_on=None): if left is None: return right @@ -208,7 +255,6 @@ def reduce_columns(columns, *clauses, **kw): in the the selectable to just those that are not repeated. """ - ignore_nonexistent_tables = kw.pop('ignore_nonexistent_tables', False) columns = util.OrderedSet(columns) @@ -317,7 +363,12 @@ def folded_equivalents(join, equivs=None): return collist class AliasedRow(object): + """Wrap a RowProxy with a translation map. + + This object allows a set of keys to be translated + to those present in a RowProxy. + """ def __init__(self, row, map): # AliasedRow objects don't nest, so un-nest # if another AliasedRow was passed @@ -341,10 +392,8 @@ class AliasedRow(object): class ClauseAdapter(visitors.ReplacingCloningVisitor): - """Given a clause (like as in a WHERE criterion), locate columns - which are embedded within a given selectable, and changes those - columns to be that of the selectable. - + """Clones and modifies clauses based on column correspondence. + E.g.:: table1 = Table('sometable', metadata, @@ -358,7 +407,7 @@ class ClauseAdapter(visitors.ReplacingCloningVisitor): condition = table1.c.col1 == table2.c.col1 - and make an alias of table1:: + make an alias of table1:: s = table1.alias('foo') @@ -401,7 +450,14 @@ class ClauseAdapter(visitors.ReplacingCloningVisitor): return self._corresponding_column(col, True) class ColumnAdapter(ClauseAdapter): - + """Extends ClauseAdapter with extra utility functions. + + Provides the ability to "wrap" this ClauseAdapter + around another, a columns dictionary which returns + cached, adapted elements given an original, and an + adapted_row() factory. + + """ def __init__(self, selectable, equivalents=None, chain_to=None, include=None, exclude=None): ClauseAdapter.__init__(self, selectable, equivalents, include, exclude) if chain_to: diff --git a/test/ext/declarative.py b/test/ext/declarative.py index 772dca8e13..1c088d1434 100644 --- a/test/ext/declarative.py +++ b/test/ext/declarative.py @@ -6,7 +6,7 @@ from testlib import sa, testing from testlib.sa import MetaData, Table, Column, Integer, String, ForeignKey, ForeignKeyConstraint, asc from testlib.sa.orm import relation, create_session, class_mapper, eagerload, compile_mappers, backref from testlib.testing import eq_ -from orm._base import ComparableEntity +from orm._base import ComparableEntity, MappedTest class DeclarativeTest(testing.TestBase, testing.AssertsExecutionResults): @@ -784,7 +784,77 @@ class DeclarativeTest(testing.TestBase, testing.AssertsExecutionResults): finally: meta.drop_all() +def produce_test(inline, stringbased): + class ExplicitJoinTest(testing.ORMTest): + + def define_tables(self, metadata): + global User, Address + Base = decl.declarative_base(metadata=metadata) + class User(Base, ComparableEntity): + __tablename__ = 'users' + id = Column(Integer, primary_key=True) + name = Column(String(50)) + + class Address(Base, ComparableEntity): + __tablename__ = 'addresses' + id = Column(Integer, primary_key=True) + email = Column(String(50)) + user_id = Column(Integer, ForeignKey('users.id')) + if inline: + if stringbased: + user = relation("User", primaryjoin="User.id==Address.user_id", backref="addresses") + else: + user = relation(User, primaryjoin=User.id==user_id, backref="addresses") + + if not inline: + compile_mappers() + if stringbased: + Address.user = relation("User", primaryjoin="User.id==Address.user_id", backref="addresses") + else: + Address.user = relation(User, primaryjoin=User.id==Address.user_id, backref="addresses") + + def insert_data(self): + params = [dict(zip(('id', 'name'), column_values)) for column_values in + [(7, 'jack'), + (8, 'ed'), + (9, 'fred'), + (10, 'chuck')] + ] + User.__table__.insert().execute(params) + + Address.__table__.insert().execute( + [dict(zip(('id', 'user_id', 'email'), column_values)) for column_values in + [(1, 7, "jack@bean.com"), + (2, 8, "ed@wood.com"), + (3, 8, "ed@bettyboop.com"), + (4, 8, "ed@lala.com"), + (5, 9, "fred@fred.com")] + ] + ) + + def test_aliased_join(self): + # this query will screw up if the aliasing + # enabled in query.join() gets applied to the right half of the join condition inside the any(). + # the join condition inside of any() comes from the "primaryjoin" of the relation, + # and should not be annotated with _orm_adapt. PropertyLoader.Comparator will annotate + # the left side with _orm_adapt, though. + sess = create_session() + eq_( + sess.query(User).join(User.addresses, aliased=True). + filter(Address.email=='ed@wood.com').filter(User.addresses.any(Address.email=='jack@bean.com')).all(), + [] + ) + + ExplicitJoinTest.__name__ = "ExplicitJoinTest%s%s" % (inline and 'Inline' or 'Separate', stringbased and 'String' or 'Literal') + return ExplicitJoinTest + +for inline in (True, False): + for stringbased in (True, False): + testclass = produce_test(inline, stringbased) + exec("%s = testclass" % testclass.__name__) + del testclass + class DeclarativeReflectionTest(testing.TestBase): def setUpAll(self): global reflection_metadata diff --git a/test/orm/inheritance/abc_inheritance.py b/test/orm/inheritance/abc_inheritance.py index e6977506ab..1a91df9875 100644 --- a/test/orm/inheritance/abc_inheritance.py +++ b/test/orm/inheritance/abc_inheritance.py @@ -163,7 +163,7 @@ for parent in ["a", "b", "c"]: for direction in [ONETOMANY, MANYTOONE]: testclass = produce_test(parent, child, direction) exec("%s = testclass" % testclass.__name__) - + del testclass if __name__ == "__main__": testenv.main() diff --git a/test/orm/inheritance/query.py b/test/orm/inheritance/query.py index eb40f01e56..601d5be6ca 100644 --- a/test/orm/inheritance/query.py +++ b/test/orm/inheritance/query.py @@ -258,6 +258,23 @@ def make_test(select_type): def test_polymorphic_any(self): sess = create_session() + self.assertEquals( + sess.query(Company).\ + filter(Company.employees.any(Person.name=='vlad')).all(), [c2] + ) + + # test that the aliasing on "Person" does not bleed into the + # EXISTS clause generated by any() + self.assertEquals( + sess.query(Company).join(Company.employees, aliased=True).filter(Person.name=='dilbert').\ + filter(Company.employees.any(Person.name=='wally')).all(), [c1] + ) + + self.assertEquals( + sess.query(Company).join(Company.employees, aliased=True).filter(Person.name=='dilbert').\ + filter(Company.employees.any(Person.name=='vlad')).all(), [] + ) + self.assertEquals( sess.query(Company).filter(Company.employees.of_type(Engineer).any(Engineer.primary_language=='cobol')).one(), c2 diff --git a/test/orm/query.py b/test/orm/query.py index dcd0ac548a..ab20754df7 100644 --- a/test/orm/query.py +++ b/test/orm/query.py @@ -41,6 +41,9 @@ class QueryTest(FixtureTest): }) mapper(Keyword, keywords) + compile_mappers() + #class_mapper(User).add_property('addresses', relation(Address, primaryjoin=User.id==Address.user_id, order_by=Address.id, backref='user')) + class UnicodeSchemaTest(QueryTest): keep_mappers = False @@ -356,11 +359,11 @@ class OperatorTest(QueryTest, AssertsCompiledSQL): "nodes.id = nodes_1.parent_id AND nodes_1.data = :data_1)" ) - # fails, needs autoaliasing - #self._test( - # Node.children==None, - # "NOT (EXISTS (SELECT 1 FROM nodes AS nodes_1 WHERE nodes.id = nodes_1.parent_id))" - #) + # needs autoaliasing + self._test( + Node.children==None, + "NOT (EXISTS (SELECT 1 FROM nodes AS nodes_1 WHERE nodes.id = nodes_1.parent_id))" + ) self._test( Node.parent==None, @@ -372,44 +375,27 @@ class OperatorTest(QueryTest, AssertsCompiledSQL): "nodes_1.parent_id IS NULL" ) - # fails, needs autoaliasing - #self._test( - # Node.children==[Node(id=1), Node(id=2)], - # "(EXISTS (SELECT 1 FROM nodes AS nodes_1 WHERE nodes.id = nodes_1.parent_id AND nodes_1.id = :id_1)) " - # "AND (EXISTS (SELECT 1 FROM nodes AS nodes_1 WHERE nodes.id = nodes_1.parent_id AND nodes_1.id = :id_2))" - #) - - # fails, overaliases - #self._test( - # nalias.children==[Node(id=1), Node(id=2)], - # "(EXISTS (SELECT 1 FROM nodes AS nodes_1 WHERE nodes.id = nodes_1.parent_id AND nodes_1.id = :id_1)) " - # "AND (EXISTS (SELECT 1 FROM nodes AS nodes_1 WHERE nodes.id = nodes_1.parent_id AND nodes_1.id = :id_2))" - #) - - # fails, overaliases - #self._test( - # nalias.children==None, - # "NOT (EXISTS (SELECT 1 FROM nodes AS nodes WHERE nodes_1.id = nodes.parent_id))" - #) + self._test( + nalias.children==None, + "NOT (EXISTS (SELECT 1 FROM nodes WHERE nodes_1.id = nodes.parent_id))" + ) - # fails - #self._test( - # nalias.children.any(Node.data=='some data'), - # "EXISTS (SELECT 1 FROM nodes WHERE " - # "nodes_1.id = nodes.parent_id AND nodes.data = :data_1)") + self._test( + nalias.children.any(Node.data=='some data'), + "EXISTS (SELECT 1 FROM nodes WHERE " + "nodes_1.id = nodes.parent_id AND nodes.data = :data_1)") - # fails + # fails, but I think I want this to fail #self._test( # Node.children.any(nalias.data=='some data'), # "EXISTS (SELECT 1 FROM nodes AS nodes_1 WHERE " # "nodes.id = nodes_1.parent_id AND nodes_1.data = :data_1)" # ) - # fails, overaliases - #self._test( - # nalias.parent.has(Node.data=='some data'), - # "EXISTS (SELECT 1 FROM nodes WHERE nodes.id = nodes_1.parent_id AND nodes.data = :data_1)" - #) + self._test( + nalias.parent.has(Node.data=='some data'), + "EXISTS (SELECT 1 FROM nodes WHERE nodes.id = nodes_1.parent_id AND nodes.data = :data_1)" + ) self._test( Node.parent.has(Node.data=='some data'), @@ -426,12 +412,10 @@ class OperatorTest(QueryTest, AssertsCompiledSQL): ":param_1 = nodes_1.parent_id" ) - # fails - # (also why are we doing an EXISTS for this??) - #self._test( - # nalias.parent != Node(id=7), - # 'NOT (EXISTS (SELECT 1 FROM nodes WHERE nodes.id = nodes_1.parent_id AND nodes.id = :id_1))' - #) + self._test( + nalias.parent != Node(id=7), + 'nodes_1.parent_id != :parent_id_1 OR nodes_1.parent_id IS NULL' + ) self._test( nalias.children.contains(Node(id=7)), "nodes_1.id = :param_1" @@ -451,8 +435,7 @@ class OperatorTest(QueryTest, AssertsCompiledSQL): def test_selfref_between(self): ualias = aliased(User) self._test(User.id.between(ualias.id, ualias.id), "users.id BETWEEN users_1.id AND users_1.id") - # fails: - # self._test(ualias.id.between(User.id, User.id), "users_1.id BETWEEN users.id AND users.id") + self._test(ualias.id.between(User.id, User.id), "users_1.id BETWEEN users.id AND users.id") def test_clauses(self): for (expr, compare) in ( @@ -569,6 +552,31 @@ class TextTest(QueryTest): def test_binds(self): assert [User(id=8), User(id=9)] == create_session().query(User).filter("id in (:id1, :id2)").params(id1=8, id2=9).all() + +class FooTest(FixtureTest): + keep_data = True + + def test_filter_by(self): + clear_mappers() + sess = create_session(bind=testing.db) + from sqlalchemy.ext.declarative import declarative_base + Base = declarative_base(bind=testing.db) + class User(Base, _base.ComparableEntity): + __table__ = users + + class Address(Base, _base.ComparableEntity): + __table__ = addresses + + compile_mappers() +# Address.user = relation(User, primaryjoin="User.id==Address.user_id") + Address.user = relation(User, primaryjoin=User.id==Address.user_id) +# Address.user = relation(User, primaryjoin=users.c.id==addresses.c.user_id) + compile_mappers() +# Address.user.property.primaryjoin = User.id==Address.user_id + user = sess.query(User).get(8) + print sess.query(Address).filter_by(user=user).all() + assert [Address(id=2), Address(id=3), Address(id=4)] == sess.query(Address).filter_by(user=user).all() + class FilterTest(QueryTest): def test_basic(self): assert [User(id=7), User(id=8), User(id=9),User(id=10)] == create_session().query(User).all() @@ -1134,12 +1142,12 @@ class JoinTest(QueryTest): assert q.count() == 1 assert [User(id=7)] == q.all() - # test the control version - same joins but not aliased. rows are not returned because order 3 does not have item 1 q = sess.query(User).join('orders').filter(Order.description=="order 3").join(['orders', 'items']).filter(Item.description=="item 1") assert [] == q.all() assert q.count() == 0 + # the left half of the join condition of the any() is aliased. q = sess.query(User).join('orders', aliased=True).filter(Order.items.any(Item.description=='item 4')) assert [User(id=7)] == q.all() diff --git a/test/sql/selectable.py b/test/sql/selectable.py index e41165b5bf..3f9464283d 100755 --- a/test/sql/selectable.py +++ b/test/sql/selectable.py @@ -5,11 +5,12 @@ every selectable unit behaving nicely with others..""" import testenv; testenv.configure_for_tests() from sqlalchemy import * from testlib import * -from sqlalchemy.sql import util as sql_util +from sqlalchemy.sql import util as sql_util, visitors from sqlalchemy import exc +from sqlalchemy.sql import table, column metadata = MetaData() -table = Table('table1', metadata, +table1 = Table('table1', metadata, Column('col1', Integer, primary_key=True), Column('col2', String(20)), Column('col3', Integer), @@ -27,16 +28,16 @@ table2 = Table('table2', metadata, class SelectableTest(TestBase, AssertsExecutionResults): def test_distance(self): # same column three times - s = select([table.c.col1.label('c2'), table.c.col1, table.c.col1.label('c1')]) + s = select([table1.c.col1.label('c2'), table1.c.col1, table1.c.col1.label('c1')]) # didnt do this yet...col.label().make_proxy() has same "distance" as col.make_proxy() so far - #assert s.corresponding_column(table.c.col1) is s.c.col1 + #assert s.corresponding_column(table1.c.col1) is s.c.col1 assert s.corresponding_column(s.c.col1) is s.c.col1 assert s.corresponding_column(s.c.c1) is s.c.c1 def test_join_against_self(self): - jj = select([table.c.col1.label('bar_col1')]) - jjj = join(table, jj, table.c.col1==jj.c.bar_col1) + jj = select([table1.c.col1.label('bar_col1')]) + jjj = join(table1, jj, table1.c.col1==jj.c.bar_col1) # test column directly agaisnt itself assert jjj.corresponding_column(jjj.c.table1_col1) is jjj.c.table1_col1 @@ -45,22 +46,22 @@ class SelectableTest(TestBase, AssertsExecutionResults): # test alias of the join, targets the column with the least # "distance" between the requested column and the returned column - # (i.e. there is less indirection between j2.c.table1_col1 and table.c.col1, than - # there is from j2.c.bar_col1 to table.c.col1) + # (i.e. there is less indirection between j2.c.table1_col1 and table1.c.col1, than + # there is from j2.c.bar_col1 to table1.c.col1) j2 = jjj.alias('foo') - assert j2.corresponding_column(table.c.col1) is j2.c.table1_col1 + assert j2.corresponding_column(table1.c.col1) is j2.c.table1_col1 def test_select_on_table(self): - sel = select([table, table2], use_labels=True) - assert sel.corresponding_column(table.c.col1) is sel.c.table1_col1 - assert sel.corresponding_column(table.c.col1, require_embedded=True) is sel.c.table1_col1 - assert table.corresponding_column(sel.c.table1_col1) is table.c.col1 - assert table.corresponding_column(sel.c.table1_col1, require_embedded=True) is None + sel = select([table1, table2], use_labels=True) + assert sel.corresponding_column(table1.c.col1) is sel.c.table1_col1 + assert sel.corresponding_column(table1.c.col1, require_embedded=True) is sel.c.table1_col1 + assert table1.corresponding_column(sel.c.table1_col1) is table1.c.col1 + assert table1.corresponding_column(sel.c.table1_col1, require_embedded=True) is None def test_join_against_join(self): - j = outerjoin(table, table2, table.c.col1==table2.c.col2) - jj = select([ table.c.col1.label('bar_col1')],from_obj=[j]).alias('foo') - jjj = join(table, jj, table.c.col1==jj.c.bar_col1) + j = outerjoin(table1, table2, table1.c.col1==table2.c.col2) + jj = select([ table1.c.col1.label('bar_col1')],from_obj=[j]).alias('foo') + jjj = join(table1, jj, table1.c.col1==jj.c.bar_col1) assert jjj.corresponding_column(jjj.c.table1_col1) is jjj.c.table1_col1 j2 = jjj.alias('foo') @@ -70,7 +71,7 @@ class SelectableTest(TestBase, AssertsExecutionResults): assert jjj.corresponding_column(jj.c.bar_col1) is jj.c.bar_col1 def test_table_alias(self): - a = table.alias('a') + a = table1.alias('a') j = join(a, table2) @@ -80,10 +81,10 @@ class SelectableTest(TestBase, AssertsExecutionResults): def test_union(self): # tests that we can correspond a column in a Select statement with a certain Table, against # a column in a Union where one of its underlying Selects matches to that same Table - u = select([table.c.col1, table.c.col2, table.c.col3, table.c.colx, null().label('coly')]).union( + u = select([table1.c.col1, table1.c.col2, table1.c.col3, table1.c.colx, null().label('coly')]).union( select([table2.c.col1, table2.c.col2, table2.c.col3, null().label('colx'), table2.c.coly]) ) - s1 = table.select(use_labels=True) + s1 = table1.select(use_labels=True) s2 = table2.select(use_labels=True) print ["%d %s" % (id(c),c.key) for c in u.c] c = u.corresponding_column(s1.c.table1_col2) @@ -94,19 +95,19 @@ class SelectableTest(TestBase, AssertsExecutionResults): assert u.corresponding_column(s2.c.table2_col2) is u.c.col2 def test_singular_union(self): - u = union(select([table.c.col1, table.c.col2, table.c.col3]), select([table.c.col1, table.c.col2, table.c.col3])) + u = union(select([table1.c.col1, table1.c.col2, table1.c.col3]), select([table1.c.col1, table1.c.col2, table1.c.col3])) - u = union(select([table.c.col1, table.c.col2, table.c.col3])) + u = union(select([table1.c.col1, table1.c.col2, table1.c.col3])) assert u.c.col1 assert u.c.col2 assert u.c.col3 def test_alias_union(self): # same as testunion, except its an alias of the union - u = select([table.c.col1, table.c.col2, table.c.col3, table.c.colx, null().label('coly')]).union( + u = select([table1.c.col1, table1.c.col2, table1.c.col3, table1.c.colx, null().label('coly')]).union( select([table2.c.col1, table2.c.col2, table2.c.col3, null().label('colx'), table2.c.coly]) ).alias('analias') - s1 = table.select(use_labels=True) + s1 = table1.select(use_labels=True) s2 = table2.select(use_labels=True) assert u.corresponding_column(s1.c.table1_col2) is u.c.col2 assert u.corresponding_column(s2.c.table2_col2) is u.c.col2 @@ -115,26 +116,26 @@ class SelectableTest(TestBase, AssertsExecutionResults): def test_select_union(self): # like testaliasunion, but off a Select off the union. - u = select([table.c.col1, table.c.col2, table.c.col3, table.c.colx, null().label('coly')]).union( + u = select([table1.c.col1, table1.c.col2, table1.c.col3, table1.c.colx, null().label('coly')]).union( select([table2.c.col1, table2.c.col2, table2.c.col3, null().label('colx'), table2.c.coly]) ).alias('analias') s = select([u]) - s1 = table.select(use_labels=True) + s1 = table1.select(use_labels=True) s2 = table2.select(use_labels=True) assert s.corresponding_column(s1.c.table1_col2) is s.c.col2 assert s.corresponding_column(s2.c.table2_col2) is s.c.col2 def test_union_against_join(self): # same as testunion, except its an alias of the union - u = select([table.c.col1, table.c.col2, table.c.col3, table.c.colx, null().label('coly')]).union( + u = select([table1.c.col1, table1.c.col2, table1.c.col3, table1.c.colx, null().label('coly')]).union( select([table2.c.col1, table2.c.col2, table2.c.col3, null().label('colx'), table2.c.coly]) ).alias('analias') - j1 = table.join(table2) + j1 = table1.join(table2) assert u.corresponding_column(j1.c.table1_colx) is u.c.colx assert j1.corresponding_column(u.c.colx) is j1.c.table1_colx def test_join(self): - a = join(table, table2) + a = join(table1, table2) print str(a.select(use_labels=True)) b = table2.alias('b') j = join(a, b) @@ -143,52 +144,40 @@ class SelectableTest(TestBase, AssertsExecutionResults): self.assert_(criterion.compare(j.onclause)) def test_select_alias(self): - a = table.select().alias('a') - print str(a.select()) + a = table1.select().alias('a') j = join(a, table2) criterion = a.c.col1 == table2.c.col2 - print criterion - print j.onclause self.assert_(criterion.compare(j.onclause)) def test_select_labels(self): - a = table.select(use_labels=True) + a = table1.select(use_labels=True) print str(a.select()) j = join(a, table2) criterion = a.c.table1_col1 == table2.c.col2 - print - print str(j) self.assert_(criterion.compare(j.onclause)) def test_column_labels(self): - a = select([table.c.col1.label('acol1'), table.c.col2.label('acol2'), table.c.col3.label('acol3')]) - print str(a) - print [c for c in a.columns] - print str(a.select()) + a = select([table1.c.col1.label('acol1'), table1.c.col2.label('acol2'), table1.c.col3.label('acol3')]) j = join(a, table2) criterion = a.c.acol1 == table2.c.col2 - print str(j) self.assert_(criterion.compare(j.onclause)) def test_labeled_select_correspoinding(self): - l1 = select([func.max(table.c.col1)]).label('foo') + l1 = select([func.max(table1.c.col1)]).label('foo') s = select([l1]) assert s.corresponding_column(l1).name == s.c.foo - s = select([table.c.col1, l1]) + s = select([table1.c.col1, l1]) assert s.corresponding_column(l1).name == s.c.foo def test_select_alias_labels(self): a = table2.select(use_labels=True).alias('a') - print str(a.select()) - j = join(a, table) + j = join(a, table1) - criterion = table.c.col1 == a.c.table2_col2 - print str(criterion) - print str(j.onclause) + criterion = table1.c.col1 == a.c.table2_col2 self.assert_(criterion.compare(j.onclause)) def test_table_joined_to_select_of_table(self): @@ -458,8 +447,6 @@ class DerivedTest(TestBase, AssertsExecutionResults): class AnnotationsTest(TestBase): def test_annotated_corresponding_column(self): - from sqlalchemy.sql import table, column - table1 = table('table1', column("col1")) s1 = select([table1.c.col1]) @@ -475,6 +462,48 @@ class AnnotationsTest(TestBase): assert inner.corresponding_column(t2.c.col1, require_embedded=False) is inner.corresponding_column(t2.c.col1, require_embedded=True) is inner.c.col1 assert inner.corresponding_column(t1.c.col1, require_embedded=False) is inner.corresponding_column(t1.c.col1, require_embedded=True) is inner.c.col1 + def test_annotated_visit(self): + table1 = table('table1', column("col1"), column("col2")) + + bin = table1.c.col1 == bindparam('foo', value=None) + assert str(bin) == "table1.col1 = :foo" + def visit_binary(b): + b.right = table1.c.col2 + + b2 = visitors.cloned_traverse(bin, {}, {'binary':visit_binary}) + assert str(b2) == "table1.col1 = table1.col2" + + b3 = visitors.cloned_traverse(bin._annotate({}), {}, {'binary':visit_binary}) + assert str(b3) == "table1.col1 = table1.col2" + + def visit_binary(b): + b.left = bindparam('bar') + + b4 = visitors.cloned_traverse(b2, {}, {'binary':visit_binary}) + assert str(b4) == ":bar = table1.col2" + + b5 = visitors.cloned_traverse(b3, {}, {'binary':visit_binary}) + assert str(b5) == ":bar = table1.col2" + + def test_deannotate(self): + table1 = table('table1', column("col1"), column("col2")) + + bin = table1.c.col1 == bindparam('foo', value=None) + + b2 = sql_util._deep_annotate(bin, {'_orm_adapt':True}) + b3 = sql_util._deep_deannotate(b2) + b4 = sql_util._deep_deannotate(bin) + + for elem in (b2._annotations, b2.left._annotations): + assert '_orm_adapt' in elem + + for elem in (b3._annotations, b3.left._annotations, b4._annotations, b4.left._annotations): + assert elem == {} + + assert b2.left is not bin.left + assert b3.left is not b2.left is not bin.left + assert b4.left is bin.left # since column is immutable + assert b4.right is not bin.right is not b2.right is not b3.right if __name__ == "__main__": testenv.main()