]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- added **kw to ClauseElement.compare(), so that we can smarten up the "use_get"...
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 8 Aug 2009 22:21:02 +0000 (22:21 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 8 Aug 2009 22:21:02 +0000 (22:21 +0000)
- many-to-one relation to a joined-table subclass now uses get()
  for a simple load (known as the "use_get" condition),
  i.e. Related->Sub(Base), without the need
  to redefine the primaryjoin condition in terms of the base
  table. [ticket:1186]
- specifying a foreign key with a declarative column,
  i.e. ForeignKey(MyRelatedClass.id) doesn't break the "use_get"
  condition from taking place [ticket:1492]

06CHANGES
lib/sqlalchemy/orm/strategies.py
lib/sqlalchemy/sql/expression.py
test/ext/test_declarative.py
test/orm/inheritance/test_basic.py

index 9ea5ec9f22bbb64a0674e8d419606c1116a6c154..4055baf15a21790f763d0484131ce42eaac1fac7 100644 (file)
--- a/06CHANGES
+++ b/06CHANGES
     - added "make_transient()" helper function which transforms a persistent/
       detached instance into a transient one (i.e. deletes the instance_key
       and removes from any session.) [ticket:1052]
-      
+    - many-to-one "lazyload" fixes:
+        - many-to-one relation to a joined-table subclass now uses get()
+          for a simple load (known as the "use_get" condition), 
+          i.e. Related->Sub(Base), without the need
+          to redefine the primaryjoin condition in terms of the base
+          table. [ticket:1186]
+        - specifying a foreign key with a declarative column,
+          i.e. ForeignKey(MyRelatedClass.id) doesn't break the "use_get"
+          condition from taking place [ticket:1492]
+    
 - sql
     - returning() support is native to insert(), update(), delete().  Implementations
       of varying levels of functionality exist for Postgresql, Firebird, MSSQL and
index c0609dba3a3df0330405cf1256da4c7413990cf8..23114cdab2499a65f569fea873754de6dbe49ebe 100644 (file)
@@ -372,8 +372,18 @@ class LazyLoader(AbstractRelationLoader):
         # determine if our "lazywhere" clause is the same as the mapper's
         # get() clause.  then we can just use mapper.get()
         #from sqlalchemy.orm import query
-        self.use_get = not self.uselist and self.mapper._get_clause[0].compare(self.__lazywhere)
+        self.use_get = not self.uselist and \
+                        self.mapper._get_clause[0].compare(
+                            self.__lazywhere, 
+                            use_proxies=True, 
+                            equivalents=self.mapper._equivalent_columns
+                        )
         if self.use_get:
+            for col in self._equated_columns.keys():
+                if col in self.mapper._equivalent_columns:
+                    for c in self.mapper._equivalent_columns[col]:
+                        self._equated_columns[c] = self._equated_columns[col]
+            
             self.logger.info("%s will use query.get() to optimize instance loads" % self)
 
     def init_class_attribute(self, mapper):
index 960fc031032dde93cbbdce395c630a558ea599f6..8c6877dbd03c6af535a7218e71a61993c79ed4af 100644 (file)
@@ -1101,11 +1101,15 @@ class ClauseElement(Visitable):
                 bind._convert_to_unique()
         return cloned_traverse(self, {}, {'bindparam':visit_bindparam})
 
-    def compare(self, other):
+    def compare(self, other, **kw):
         """Compare this ClauseElement to the given ClauseElement.
 
         Subclasses should override the default behavior, which is a
         straight identity comparison.
+        
+        **kw are arguments consumed by subclass compare() methods and
+        may be used to modify the criteria for comparison.
+        (see :class:`ColumnElement`)
 
         """
         return self is other
@@ -1697,6 +1701,34 @@ class ColumnElement(ClauseElement, _CompareMixin):
         selectable.columns[name] = co
         return co
 
+    def compare(self, other, use_proxies=False, equivalents=None, **kw):
+        """Compare this ColumnElement to another.
+        
+        Special arguments understood:
+        
+        :param use_proxies: when True, consider two columns that
+        share a common base column as equivalent (i.e. shares_lineage())
+        
+        :param equivalents: a dictionary of columns as keys mapped to sets
+        of columns.  If the given "other" column is present in this dictionary,
+        if any of the columns in the correponding set() pass the comparison 
+        test, the result is True.  This is used to expand the comparison to
+        other columns that may be known to be equivalent to this one via 
+        foreign key or other criterion.
+
+        """
+        to_compare = (other, )
+        if equivalents and other in equivalents:
+            to_compare = equivalents[other].union(to_compare)
+
+        for oth in to_compare:
+            if use_proxies and self.shares_lineage(oth):
+                return True
+            elif oth is self:
+                return True
+        else:
+            return False
+
     @util.memoized_property
     def anon_label(self):
         """provides a constant 'anonymous label' for this ColumnElement.
@@ -2109,7 +2141,7 @@ class _BindParamClause(ColumnElement):
         else:
             return obj.type
 
-    def compare(self, other):
+    def compare(self, other, **kw):
         """Compare this ``_BindParamClause`` to the given clause.
 
         Since ``compare()`` is meant to compare statement syntax, this
@@ -2274,16 +2306,16 @@ class ClauseList(ClauseElement):
         else:
             return self
 
-    def compare(self, other):
+    def compare(self, other, **kw):
         """Compare this ``ClauseList`` to the given ``ClauseList``,
         including a comparison of all the clause items.
 
         """
         if not isinstance(other, ClauseList) and len(self.clauses) == 1:
-            return self.clauses[0].compare(other)
+            return self.clauses[0].compare(other, **kw)
         elif isinstance(other, ClauseList) and len(self.clauses) == len(other.clauses):
             for i in range(0, len(self.clauses)):
-                if not self.clauses[i].compare(other.clauses[i]):
+                if not self.clauses[i].compare(other.clauses[i], **kw):
                     return False
             else:
                 return self.operator == other.operator
@@ -2473,14 +2505,14 @@ class _UnaryExpression(ColumnElement):
     def get_children(self, **kwargs):
         return self.element,
 
-    def compare(self, other):
+    def compare(self, other, **kw):
         """Compare this ``_UnaryExpression`` against the given ``ClauseElement``."""
 
         return (
             isinstance(other, _UnaryExpression) and
             self.operator == other.operator and
             self.modifier == other.modifier and
-            self.element.compare(other.element)
+            self.element.compare(other.element, **kw)
         )
 
     def _negate(self):
@@ -2528,19 +2560,19 @@ class _BinaryExpression(ColumnElement):
     def get_children(self, **kwargs):
         return self.left, self.right
 
-    def compare(self, other):
+    def compare(self, other, **kw):
         """Compare this ``_BinaryExpression`` against the given ``_BinaryExpression``."""
 
         return (
             isinstance(other, _BinaryExpression) and
             self.operator == other.operator and
             (
-                self.left.compare(other.left) and
-                self.right.compare(other.right) or
+                self.left.compare(other.left, **kw) and
+                self.right.compare(other.right, **kw) or
                 (
                     operators.is_commutative(self.operator) and
-                    self.left.compare(other.right) and
-                    self.right.compare(other.left)
+                    self.left.compare(other.right, **kw) and
+                    self.right.compare(other.left, **kw)
                 )
             )
         )
index 745e3b7cf8cab3169a8bf7ae6472e11518a1e7b5..63beb54e7a1352616935f4da1243e31fc8ad7a0b 100644 (file)
@@ -246,6 +246,37 @@ class DeclarativeTest(DeclarativeTestBase):
         Base = decl.declarative_base(cls=MyBase)
         assert hasattr(Base, 'metadata')
         assert Base().foobar() == "foobar"
+    
+    def test_uses_get_on_class_col_fk(self):
+        # test [ticket:1492]
+        
+        class Master(Base): 
+            __tablename__ = 'master' 
+            id = Column(Integer, primary_key=True) 
+
+        class Detail(Base): 
+            __tablename__ = 'detail' 
+            id = Column(Integer, primary_key=True) 
+            master_id = Column(None, ForeignKey(Master.id)) 
+            master = relation(Master) 
+
+        Base.metadata.create_all()
+        
+        compile_mappers()
+        assert class_mapper(Detail).get_property('master').strategy.use_get
+        
+        m1 = Master()
+        d1 = Detail(master=m1)
+        sess = create_session()
+        sess.add(d1)
+        sess.flush()
+        sess.expunge_all()
+
+        d1 = sess.query(Detail).first()
+        m1 = sess.query(Master).first()
+        def go():
+            assert d1.master
+        self.assert_sql_count(testing.db, go, 0)
         
     def test_index_doesnt_compile(self):
         class User(Base):
index b2e00de3598261c24b53847425964b128ea35cdb..778b08a272a07ab1808f414b3fdebc572cbee2ca 100644 (file)
@@ -208,6 +208,58 @@ class CascadeTest(_base.MappedTest):
         assert t4_1 in sess.deleted
         sess.flush()
 
+class M2OUseGetTest(_base.MappedTest):
+    @classmethod
+    def define_tables(cls, metadata):
+        Table('base', metadata,
+            Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
+            Column('type', String(30))
+        )
+        Table('sub', metadata,
+            Column('id', Integer, ForeignKey('base.id'), primary_key=True),
+        )
+        Table('related', metadata,
+            Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
+            Column('sub_id', Integer, ForeignKey('sub.id')),
+        )
+
+    @testing.resolve_artifact_names
+    def test_use_get(self):
+        # test [ticket:1186]
+        class Base(_base.BasicEntity):
+            pass
+        class Sub(Base):
+            pass
+        class Related(Base):
+            pass
+        mapper(Base, base, polymorphic_on=base.c.type, polymorphic_identity='b')
+        mapper(Sub, sub, inherits=Base, polymorphic_identity='s')
+        mapper(Related, related, properties={
+            # previously, this was needed for the comparison to occur:
+            # the 'primaryjoin' looks just like "Sub"'s "get" clause (based on the Base id),
+            # and foreign_keys since that join condition doesn't actually have any fks in it
+            #'sub':relation(Sub, primaryjoin=base.c.id==related.c.sub_id, foreign_keys=related.c.sub_id)
+            
+            # now we can use this:
+            'sub':relation(Sub)
+        })
+        
+        assert class_mapper(Related).get_property('sub').strategy.use_get
+        
+        sess = create_session()
+        s1 = Sub()
+        r1 = Related(sub=s1)
+        sess.add(r1)
+        sess.flush()
+        sess.expunge_all()
+
+        r1 = sess.query(Related).first()
+        s1 = sess.query(Sub).first()
+        def go():
+            assert r1.sub
+        self.assert_sql_count(testing.db, go, 0)
+        
+
 class GetTest(_base.MappedTest):
     @classmethod
     def define_tables(cls, metadata):