]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
initial annotations approach to join conditions. all tests pass, plus additional...
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 6 Feb 2012 17:20:15 +0000 (12:20 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 6 Feb 2012 17:20:15 +0000 (12:20 -0500)
would now like to reorganize RelationshipProperty more around the annotations concept.

lib/sqlalchemy/orm/properties.py
lib/sqlalchemy/sql/expression.py
lib/sqlalchemy/sql/util.py
lib/sqlalchemy/sql/visitors.py
test/orm/test_joins.py
test/orm/test_query.py
test/orm/test_relationships.py
test/sql/test_selectable.py

index 59c4cb3dc1932ba46dec9e71f4c8ee8a95a5ec63..a590ad7069d3d8b8d46b3658028e4ab6c34488a2 100644 (file)
@@ -14,7 +14,7 @@ mapped attributes.
 from sqlalchemy import sql, util, log, exc as sa_exc
 from sqlalchemy.sql.util import ClauseAdapter, criterion_as_pairs, \
     join_condition, _shallow_annotate
-from sqlalchemy.sql import operators, expression
+from sqlalchemy.sql import operators, expression, visitors
 from sqlalchemy.orm import attributes, dependency, mapper, \
     object_mapper, strategies, configure_mappers
 from sqlalchemy.orm.util import CascadeOptions, _class_to_mapper, \
@@ -444,6 +444,7 @@ class RelationshipProperty(StrategizedProperty):
             else:
                 j = _orm_annotate(pj, exclude=self.property.remote_side)
 
+            # MARKMARK
             if criterion is not None and target_adapter:
                 # limit this adapter to annotated only?
                 criterion = target_adapter.traverse(criterion)
@@ -1376,6 +1377,34 @@ class RelationshipProperty(StrategizedProperty):
                             "argument to indicate which column lazy join "
                             "condition should bind." % (col, self.mapper))
 
+        count = [0]
+        def clone(elem):
+            if set(['local', 'remote']).intersection(elem._annotations):
+                return None
+            elif elem in self.local_side and elem in self.remote_side:
+                # TODO: OK this still sucks.  this is basically,
+                # refuse, refuse, refuse the temptation to guess!
+                # but crap we really have to guess don't we.  we 
+                # might want to traverse here with cloned_traverse
+                # so we can see the binary exprs and do it at that 
+                # level....
+                if count[0] % 2 == 0:
+                    elem = elem._annotate({'local':True})
+                else:
+                    elem = elem._annotate({'remote':True})
+                count[0] += 1
+            elif elem in self.local_side:
+                elem = elem._annotate({'local':True})
+            elif elem in self.remote_side:
+                elem = elem._annotate({'remote':True})
+            else:
+                elem = None
+            return elem
+
+        self.primaryjoin = visitors.replacement_traverse(
+                                self.primaryjoin, {}, clone
+                            )
+
     def _generate_backref(self):
         if not self.is_primary():
             return
@@ -1539,17 +1568,20 @@ class RelationshipProperty(StrategizedProperty):
                     secondary_aliasizer.traverse(secondaryjoin)
             else:
                 primary_aliasizer = ClauseAdapter(dest_selectable,
-                        exclude=self.local_side,
+                        #exclude=self.local_side,
+                        exclude_fn=lambda c: "local" in c._annotations,
                         equivalents=self.mapper._equivalent_columns)
                 if source_selectable is not None:
                     primary_aliasizer.chain(
                         ClauseAdapter(source_selectable,
-                            exclude=self.remote_side,
+                            #exclude=self.remote_side,
+                            exclude_fn=lambda c: "remote" in c._annotations,
                             equivalents=self.parent._equivalent_columns))
                 secondary_aliasizer = None
+
             primaryjoin = primary_aliasizer.traverse(primaryjoin)
             target_adapter = secondary_aliasizer or primary_aliasizer
-            target_adapter.include = target_adapter.exclude = None
+            target_adapter.include = target_adapter.exclude = target_adapter.exclude_fn = None
         else:
             target_adapter = None
         if source_selectable is None:
index b11e5ad429845bf22df7d2d2d104054ecdfa9e0d..30e19bc6864237998756369e2e92d342f61f63fb 100644 (file)
@@ -2184,7 +2184,7 @@ class ColumnElement(ClauseElement, _CompareMixin):
         for oth in to_compare:
             if use_proxies and self.shares_lineage(oth):
                 return True
-            elif oth is self:
+            elif hash(oth) == hash(self):
                 return True
         else:
             return False
index 97975441e48e2414f204d4659b04b52b6656076c..f0509c16f73a987535a8242ceff77a2e62c35608 100644 (file)
@@ -225,7 +225,8 @@ def adapt_criterion_to_null(crit, nulls):
 
     return visitors.cloned_traverse(crit, {}, {'binary':visit_binary})
 
-def join_condition(a, b, ignore_nonexistent_tables=False, a_subset=None):
+def join_condition(a, b, ignore_nonexistent_tables=False, 
+                            a_subset=None):
     """create a join condition between two tables or selectables.
 
     e.g.::
@@ -535,6 +536,10 @@ def criterion_as_pairs(expression, consider_as_foreign_keys=None,
                                 "'consider_as_foreign_keys' or "
                                 "'consider_as_referenced_keys'")
 
+    def col_is(a, b):
+        #return a is b
+        return a.compare(b)
+
     def visit_binary(binary):
         if not any_operator and binary.operator is not operators.eq:
             return
@@ -544,20 +549,20 @@ def criterion_as_pairs(expression, consider_as_foreign_keys=None,
 
         if consider_as_foreign_keys:
             if binary.left in consider_as_foreign_keys and \
-                        (binary.right is binary.left or 
+                        (col_is(binary.right, binary.left) or 
                         binary.right not in consider_as_foreign_keys):
                 pairs.append((binary.right, binary.left))
             elif binary.right in consider_as_foreign_keys and \
-                        (binary.left is binary.right or 
+                        (col_is(binary.left, binary.right) or 
                         binary.left not in consider_as_foreign_keys):
                 pairs.append((binary.left, binary.right))
         elif consider_as_referenced_keys:
             if binary.left in consider_as_referenced_keys and \
-                        (binary.right is binary.left or 
+                        (col_is(binary.right, binary.left) or 
                         binary.right not in consider_as_referenced_keys):
                 pairs.append((binary.left, binary.right))
             elif binary.right in consider_as_referenced_keys and \
-                        (binary.left is binary.right or 
+                        (col_is(binary.left, binary.right) or 
                         binary.left not in consider_as_referenced_keys):
                 pairs.append((binary.right, binary.left))
         else:
@@ -669,11 +674,22 @@ class ClauseAdapter(visitors.ReplacingCloningVisitor):
       s.c.col1 == table2.c.col1
 
     """
-    def __init__(self, selectable, equivalents=None, include=None, exclude=None, adapt_on_names=False):
+    def __init__(self, selectable, equivalents=None, 
+                        include=None, exclude=None, 
+                        include_fn=None, exclude_fn=None, 
+                        adapt_on_names=False):
         self.__traverse_options__ = {'stop_on':[selectable]}
         self.selectable = selectable
-        self.include = include
-        self.exclude = exclude
+        if include:
+            assert not include_fn
+            self.include_fn = lambda e: e in include
+        else:
+            self.include_fn = include_fn
+        if exclude:
+            assert not exclude_fn
+            self.exclude_fn = lambda e: e in exclude
+        else:
+            self.exclude_fn = exclude_fn
         self.equivalents = util.column_dict(equivalents or {})
         self.adapt_on_names = adapt_on_names
 
@@ -693,19 +709,17 @@ class ClauseAdapter(visitors.ReplacingCloningVisitor):
         return newcol
 
     def replace(self, col):
-        if isinstance(col, expression.FromClause):
-            if self.selectable.is_derived_from(col):
+        if isinstance(col, expression.FromClause) and \
+            self.selectable.is_derived_from(col):
                 return self.selectable
-
-        if not isinstance(col, expression.ColumnElement):
+        elif not isinstance(col, expression.ColumnElement):
             return None
-
-        if self.include and col not in self.include:
+        elif self.include_fn and not self.include_fn(col):
             return None
-        elif self.exclude and col in self.exclude:
+        elif self.exclude_fn and self.exclude_fn(col):
             return None
-
-        return self._corresponding_column(col, True)
+        else:
+            return self._corresponding_column(col, True)
 
 class ColumnAdapter(ClauseAdapter):
     """Extends ClauseAdapter with extra utility functions.
index cdcf40aa86ba5c7ecaa4d89b15d7f7c84d406f18..75e099f0d9958dfc0126cf1ad11f9d92a854bf06 100644 (file)
@@ -240,16 +240,16 @@ def replacement_traverse(obj, opts, replace):
     replacement by a given replacement function."""
 
     cloned = util.column_dict()
-    stop_on = util.column_set(opts.get('stop_on', []))
+    stop_on = util.column_set([id(x) for x in opts.get('stop_on', [])])
 
     def clone(elem, **kw):
-        if elem in stop_on or \
+        if id(elem) in stop_on or \
             'no_replacement_traverse' in elem._annotations:
             return elem
         else:
             newelem = replace(elem)
             if newelem is not None:
-                stop_on.add(newelem)
+                stop_on.add(id(newelem))
                 return newelem
             else:
                 if elem not in cloned:
index db7c78cdd697bde82f8901407a52e54f520a6d23..6c43a2f39840fabf43e1d8c0746283eea73556b0 100644 (file)
@@ -1700,21 +1700,29 @@ class SelfReferentialTest(fixtures.MappedTest, AssertsCompiledSQL):
         sess.flush()
         sess.close()
 
-    def test_join(self):
+    def test_join_1(self):
         Node = self.classes.Node
-
         sess = create_session()
 
         node = sess.query(Node).join('children', aliased=True).filter_by(data='n122').first()
         assert node.data=='n12'
 
+    def test_join_2(self):
+        Node = self.classes.Node
+        sess = create_session()
         ret = sess.query(Node.data).join(Node.children, aliased=True).filter_by(data='n122').all()
         assert ret == [('n12',)]
 
 
+    def test_join_3(self):
+        Node = self.classes.Node
+        sess = create_session()
         node = sess.query(Node).join('children', 'children', aliased=True).filter_by(data='n122').first()
         assert node.data=='n1'
 
+    def test_join_4(self):
+        Node = self.classes.Node
+        sess = create_session()
         node = sess.query(Node).filter_by(data='n122').join('parent', aliased=True).filter_by(data='n12').\
             join('parent', aliased=True, from_joinpoint=True).filter_by(data='n1').first()
         assert node.data == 'n122'
index 24974ae7e3f22fb2ba8388923df543e157c7d876..155f7c68d55f9f4192b11aae7777f070056e8367 100644 (file)
@@ -622,6 +622,14 @@ class OperatorTest(QueryTest, AssertsCompiledSQL):
         self._test(Address.user == None, "addresses.user_id IS NULL")
 
         self._test(Address.user != None, "addresses.user_id IS NOT NULL")
+    
+    def test_foo(self):
+        Node = self.classes.Node
+        nalias = aliased(Node)
+        self._test(
+            nalias.parent.has(Node.data=='some data'), 
+           "EXISTS (SELECT 1 FROM nodes WHERE nodes.id = nodes_1.parent_id AND nodes.data = :data_1)"
+        )
 
     def test_selfref_relationship(self):
         Node = self.classes.Node
index 6781d71045f4ad5619620a32851cd9a5437f757d..2049088aff29680bb0c9af1efd668bebd8664341 100644 (file)
@@ -249,6 +249,8 @@ class CompositeSelfRefFKTest(fixtures.MappedTest):
 
     def _test(self):
         Employee, Company = self.classes.Employee, self.classes.Company
+#        employee_t = self.tables.employee_t
+#        assert Employee.reports_to.property.local_remote_pairs == [(employee_t.c.reports_to_id, employee_t.c.emp_id), (employee_t.c.company_id, employee_t.c.company_id)]
 
         sess = create_session()
         c1 = Company()
index 8f599f1d6dd675a4d3e99f74d1dc5a31d6410b77..6d85f7c4f3e16f906cc2d0ef1b0bf2e54196ef7e 100644 (file)
@@ -1023,6 +1023,25 @@ class AnnotationsTest(fixtures.TestBase):
             annot = obj._annotate({})
             eq_(set([obj]), set([annot]))
 
+    def test_compare(self):
+        t = table('t', column('x'), column('y'))
+        x_a = t.c.x._annotate({})
+        assert t.c.x.compare(x_a)
+        assert x_a.compare(t.c.x)
+        assert not x_a.compare(t.c.y)
+        assert not t.c.y.compare(x_a)
+        assert (t.c.x == 5).compare(x_a == 5)
+        assert not (t.c.y == 5).compare(x_a == 5)
+
+        s = select([t])
+        x_p = s.c.x
+        assert not x_a.compare(x_p)
+        assert not t.c.x.compare(x_p)
+        x_p_a = x_p._annotate({})
+        assert x_p_a.compare(x_p)
+        assert x_p.compare(x_p_a)
+        assert not x_p_a.compare(x_a)
+
     def test_custom_constructions(self):
         from sqlalchemy.schema import Column
         class MyColumn(Column):