From: Mike Bayer Date: Sat, 8 Aug 2009 22:21:02 +0000 (+0000) Subject: - added **kw to ClauseElement.compare(), so that we can smarten up the "use_get"... X-Git-Tag: rel_0_6beta1~347 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=a04da2a417726055da205bd604a1993eb4fd9887;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - added **kw to ClauseElement.compare(), so that we can smarten up the "use_get" operation - many-to-one relation to a joined-table subclass now uses get() for a simple load (known as the "use_get" condition), i.e. Related->Sub(Base), without the need to redefine the primaryjoin condition in terms of the base table. [ticket:1186] - specifying a foreign key with a declarative column, i.e. ForeignKey(MyRelatedClass.id) doesn't break the "use_get" condition from taking place [ticket:1492] --- diff --git a/06CHANGES b/06CHANGES index 9ea5ec9f22..4055baf15a 100644 --- a/06CHANGES +++ b/06CHANGES @@ -14,7 +14,16 @@ - added "make_transient()" helper function which transforms a persistent/ detached instance into a transient one (i.e. deletes the instance_key and removes from any session.) [ticket:1052] - + - many-to-one "lazyload" fixes: + - many-to-one relation to a joined-table subclass now uses get() + for a simple load (known as the "use_get" condition), + i.e. Related->Sub(Base), without the need + to redefine the primaryjoin condition in terms of the base + table. [ticket:1186] + - specifying a foreign key with a declarative column, + i.e. ForeignKey(MyRelatedClass.id) doesn't break the "use_get" + condition from taking place [ticket:1492] + - sql - returning() support is native to insert(), update(), delete(). Implementations of varying levels of functionality exist for Postgresql, Firebird, MSSQL and diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index c0609dba3a..23114cdab2 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -372,8 +372,18 @@ class LazyLoader(AbstractRelationLoader): # determine if our "lazywhere" clause is the same as the mapper's # get() clause. then we can just use mapper.get() #from sqlalchemy.orm import query - self.use_get = not self.uselist and self.mapper._get_clause[0].compare(self.__lazywhere) + self.use_get = not self.uselist and \ + self.mapper._get_clause[0].compare( + self.__lazywhere, + use_proxies=True, + equivalents=self.mapper._equivalent_columns + ) if self.use_get: + for col in self._equated_columns.keys(): + if col in self.mapper._equivalent_columns: + for c in self.mapper._equivalent_columns[col]: + self._equated_columns[c] = self._equated_columns[col] + self.logger.info("%s will use query.get() to optimize instance loads" % self) def init_class_attribute(self, mapper): diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index 960fc03103..8c6877dbd0 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -1101,11 +1101,15 @@ class ClauseElement(Visitable): bind._convert_to_unique() return cloned_traverse(self, {}, {'bindparam':visit_bindparam}) - def compare(self, other): + def compare(self, other, **kw): """Compare this ClauseElement to the given ClauseElement. Subclasses should override the default behavior, which is a straight identity comparison. + + **kw are arguments consumed by subclass compare() methods and + may be used to modify the criteria for comparison. + (see :class:`ColumnElement`) """ return self is other @@ -1697,6 +1701,34 @@ class ColumnElement(ClauseElement, _CompareMixin): selectable.columns[name] = co return co + def compare(self, other, use_proxies=False, equivalents=None, **kw): + """Compare this ColumnElement to another. + + Special arguments understood: + + :param use_proxies: when True, consider two columns that + share a common base column as equivalent (i.e. shares_lineage()) + + :param equivalents: a dictionary of columns as keys mapped to sets + of columns. If the given "other" column is present in this dictionary, + if any of the columns in the correponding set() pass the comparison + test, the result is True. This is used to expand the comparison to + other columns that may be known to be equivalent to this one via + foreign key or other criterion. + + """ + to_compare = (other, ) + if equivalents and other in equivalents: + to_compare = equivalents[other].union(to_compare) + + for oth in to_compare: + if use_proxies and self.shares_lineage(oth): + return True + elif oth is self: + return True + else: + return False + @util.memoized_property def anon_label(self): """provides a constant 'anonymous label' for this ColumnElement. @@ -2109,7 +2141,7 @@ class _BindParamClause(ColumnElement): else: return obj.type - def compare(self, other): + def compare(self, other, **kw): """Compare this ``_BindParamClause`` to the given clause. Since ``compare()`` is meant to compare statement syntax, this @@ -2274,16 +2306,16 @@ class ClauseList(ClauseElement): else: return self - def compare(self, other): + def compare(self, other, **kw): """Compare this ``ClauseList`` to the given ``ClauseList``, including a comparison of all the clause items. """ if not isinstance(other, ClauseList) and len(self.clauses) == 1: - return self.clauses[0].compare(other) + return self.clauses[0].compare(other, **kw) elif isinstance(other, ClauseList) and len(self.clauses) == len(other.clauses): for i in range(0, len(self.clauses)): - if not self.clauses[i].compare(other.clauses[i]): + if not self.clauses[i].compare(other.clauses[i], **kw): return False else: return self.operator == other.operator @@ -2473,14 +2505,14 @@ class _UnaryExpression(ColumnElement): def get_children(self, **kwargs): return self.element, - def compare(self, other): + def compare(self, other, **kw): """Compare this ``_UnaryExpression`` against the given ``ClauseElement``.""" return ( isinstance(other, _UnaryExpression) and self.operator == other.operator and self.modifier == other.modifier and - self.element.compare(other.element) + self.element.compare(other.element, **kw) ) def _negate(self): @@ -2528,19 +2560,19 @@ class _BinaryExpression(ColumnElement): def get_children(self, **kwargs): return self.left, self.right - def compare(self, other): + def compare(self, other, **kw): """Compare this ``_BinaryExpression`` against the given ``_BinaryExpression``.""" return ( isinstance(other, _BinaryExpression) and self.operator == other.operator and ( - self.left.compare(other.left) and - self.right.compare(other.right) or + self.left.compare(other.left, **kw) and + self.right.compare(other.right, **kw) or ( operators.is_commutative(self.operator) and - self.left.compare(other.right) and - self.right.compare(other.left) + self.left.compare(other.right, **kw) and + self.right.compare(other.left, **kw) ) ) ) diff --git a/test/ext/test_declarative.py b/test/ext/test_declarative.py index 745e3b7cf8..63beb54e7a 100644 --- a/test/ext/test_declarative.py +++ b/test/ext/test_declarative.py @@ -246,6 +246,37 @@ class DeclarativeTest(DeclarativeTestBase): Base = decl.declarative_base(cls=MyBase) assert hasattr(Base, 'metadata') assert Base().foobar() == "foobar" + + def test_uses_get_on_class_col_fk(self): + # test [ticket:1492] + + class Master(Base): + __tablename__ = 'master' + id = Column(Integer, primary_key=True) + + class Detail(Base): + __tablename__ = 'detail' + id = Column(Integer, primary_key=True) + master_id = Column(None, ForeignKey(Master.id)) + master = relation(Master) + + Base.metadata.create_all() + + compile_mappers() + assert class_mapper(Detail).get_property('master').strategy.use_get + + m1 = Master() + d1 = Detail(master=m1) + sess = create_session() + sess.add(d1) + sess.flush() + sess.expunge_all() + + d1 = sess.query(Detail).first() + m1 = sess.query(Master).first() + def go(): + assert d1.master + self.assert_sql_count(testing.db, go, 0) def test_index_doesnt_compile(self): class User(Base): diff --git a/test/orm/inheritance/test_basic.py b/test/orm/inheritance/test_basic.py index b2e00de359..778b08a272 100644 --- a/test/orm/inheritance/test_basic.py +++ b/test/orm/inheritance/test_basic.py @@ -208,6 +208,58 @@ class CascadeTest(_base.MappedTest): assert t4_1 in sess.deleted sess.flush() +class M2OUseGetTest(_base.MappedTest): + @classmethod + def define_tables(cls, metadata): + Table('base', metadata, + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), + Column('type', String(30)) + ) + Table('sub', metadata, + Column('id', Integer, ForeignKey('base.id'), primary_key=True), + ) + Table('related', metadata, + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), + Column('sub_id', Integer, ForeignKey('sub.id')), + ) + + @testing.resolve_artifact_names + def test_use_get(self): + # test [ticket:1186] + class Base(_base.BasicEntity): + pass + class Sub(Base): + pass + class Related(Base): + pass + mapper(Base, base, polymorphic_on=base.c.type, polymorphic_identity='b') + mapper(Sub, sub, inherits=Base, polymorphic_identity='s') + mapper(Related, related, properties={ + # previously, this was needed for the comparison to occur: + # the 'primaryjoin' looks just like "Sub"'s "get" clause (based on the Base id), + # and foreign_keys since that join condition doesn't actually have any fks in it + #'sub':relation(Sub, primaryjoin=base.c.id==related.c.sub_id, foreign_keys=related.c.sub_id) + + # now we can use this: + 'sub':relation(Sub) + }) + + assert class_mapper(Related).get_property('sub').strategy.use_get + + sess = create_session() + s1 = Sub() + r1 = Related(sub=s1) + sess.add(r1) + sess.flush() + sess.expunge_all() + + r1 = sess.query(Related).first() + s1 = sess.query(Sub).first() + def go(): + assert r1.sub + self.assert_sql_count(testing.db, go, 0) + + class GetTest(_base.MappedTest): @classmethod def define_tables(cls, metadata):