]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- added concept of 'require_embedded' to corresponding_column.
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 5 Mar 2007 23:08:52 +0000 (23:08 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 5 Mar 2007 23:08:52 +0000 (23:08 +0000)
requires that the target column be present in a sub-element of the
target selectable.
- embedded logic above more appropriate for ClauseAdapter functionality
since its trying to "pull up" clauses that represent columns within
a larger union up to the level of the union itself.
- the "direction" test against the "foreign_keys" collection apparently
works for an exact "column 'x' is present in the collection", no proxy
relationships needed.  fixes the case of relating a selectable/alias
to one of its underlying tables, probably fixes other scenarios

CHANGES
lib/sqlalchemy/ansisql.py
lib/sqlalchemy/orm/properties.py
lib/sqlalchemy/sql.py
lib/sqlalchemy/sql_util.py
test/orm/relationships.py

diff --git a/CHANGES b/CHANGES
index e830654e40f28d2412620d0c01a84c458971c830..e3fc2832f0901b55c13950c2babd93312fba7cf9 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -31,7 +31,9 @@
     - added "refresh-expire" cascade [ticket:492]
     - more fixes to polymorphic relations, involving proper lazy-clause
       generation on many-to-one relationships to polymorphic mappers 
-      [ticket:493]
+      [ticket:493]. also fixes to detection of "direction", more specific
+      targeting of columns that belong to the polymorphic union vs. those
+      that dont.
     - put an aggressive check for "flushing object A with a collection
       of B's, but you put a C in the collection" error condition - 
       **even if C is a subclass of B**, unless B's mapper loads polymorphically.
@@ -39,6 +41,7 @@
       (since its not polymorphic) which breaks in bi-directional relationships
       (i.e. C has its A, but A's backref will lazyload it as a different 
       instance of type "B") [ticket:500]
+
 0.3.5
 - sql:
     - the value of "case_sensitive" defaults to True now, regardless of the
index 19cde38628f3739564fb7b0977520531fb4db63d..5d5c42208ccbe10d386790d1ffe2867fc04c7ea5 100644 (file)
@@ -391,9 +391,9 @@ class ANSICompiler(sql.Compiled):
                 c.accept_visitor(self)
                 inner_columns[self.get_str(c)] = c
                 continue
-            try:
+            if hasattr(c, '_selectable'):
                 s = c._selectable()
-            except AttributeError:
+            else:
                 c.accept_visitor(self)
                 inner_columns[self.get_str(c)] = c
                 continue
index 98b386edd1a21945d062bbafb6eb44fa6c00daac..95f6c1b3b95f0d48aba04dee5ed9ff50bb0bef86 100644 (file)
@@ -302,8 +302,9 @@ class PropertyLoader(StrategizedProperty):
             else:
                 self.direction = sync.ONETOMANY
         else:
-            onetomany = len([c for c in self.foreign_keys if self.mapper.unjoined_table.corresponding_column(c, False) is not None])
-            manytoone = len([c for c in self.foreign_keys if self.parent.unjoined_table.corresponding_column(c, False) is not None])
+            onetomany = len([c for c in self.foreign_keys if self.mapper.unjoined_table.c.contains_column(c)])
+            manytoone = len([c for c in self.foreign_keys if self.parent.unjoined_table.c.contains_column(c)])
+
             if not onetomany and not manytoone:
                 raise exceptions.ArgumentError("Cant determine relation direction for relationship '%s' - foreign key columns are not present in neither the parent nor the child's mapped tables" %(str(self)))
             elif onetomany and manytoone:
index da1afe7992a19ebb90f372731192b7e6de5e78d1..073277d576b1809ebf977d3bfa3a9ef7f92c5890 100644 (file)
@@ -506,6 +506,21 @@ class ClauseVisitor(object):
     def visit_label(self, label):pass
     def visit_typeclause(self, typeclause):pass
 
+class VisitColumnMixin(object):
+    """a mixin that adds Column traversal to a ClauseVisitor"""
+    def visit_table(self, table):
+        for c in table.c:
+            c.accept_visitor(self)
+    def visit_select(self, select):
+        for c in select.c:
+            c.accept_visitor(self)
+    def visit_compound_select(self, select):
+        for c in select.c:
+            c.accept_visitor(self)
+    def visit_alias(self, alias):
+        for c in alias.c:
+            c.accept_visitor(self)
+        
 class Executor(object):
     """Represent a *thing that can produce Compiled objects and execute them*."""
 
@@ -1041,20 +1056,25 @@ class FromClause(Selectable):
             self._oid_column = self._locate_oid_column()
         return self._oid_column
 
-    def corresponding_column(self, column, raiseerr=True, keys_ok=False, require_exact=False):
+    def _get_all_embedded_columns(self):
+        ret = []
+        class FindCols(VisitColumnMixin, ClauseVisitor):
+            def visit_column(self, col):
+                ret.append(col)
+        self.accept_visitor(FindCols())
+        return ret
+
+    def corresponding_column(self, column, raiseerr=True, keys_ok=False, require_embedded=False):
         """Given a ``ColumnElement``, return the ``ColumnElement``
         object from this ``Selectable`` which corresponds to that
         original ``Column`` via a proxy relationship.
         """
 
-        if require_exact:
-            if self.columns.get(column.name) is column:
-                return column
+        if require_embedded and column not in util.Set(self._get_all_embedded_columns()):
+            if not raiseerr:
+                return None
             else:
-                if not raiseerr:
-                    return None
-                else:
-                    raise exceptions.InvalidRequestError("Column instance '%s' is not directly present in table '%s'" % (str(column), str(column.table)))
+                raise exceptions.InvalidRequestError("Column instance '%s' is not directly present within selectable '%s'" % (str(column), column.table))
         for c in column.orig_set:
             try:
                 return self.original_columns[c]
index 3eb4b6d06cc130f18ee02981fa6eff5996250ae2..70fc85702e198e28e10adb533bbfdd92719a36bd 100644 (file)
@@ -178,7 +178,7 @@ class Aliasizer(AbstractClauseProcessor):
 
 class ClauseAdapter(AbstractClauseProcessor):
     """Given a clause (like as in a WHERE criterion), locate columns
-    which *correspond* to a given selectable, and changes those
+    which are embedded within a given selectable, and changes those
     columns to be that of the selectable.
 
     E.g.::
@@ -219,10 +219,10 @@ class ClauseAdapter(AbstractClauseProcessor):
         if self.exclude is not None:
             if col in self.exclude:
                 return None
-        newcol = self.selectable.corresponding_column(col, raiseerr=False, keys_ok=False)
+        newcol = self.selectable.corresponding_column(col, raiseerr=False, require_embedded=True, keys_ok=False)
         if newcol is None and self.equivalents is not None and col in self.equivalents:
             for equiv in self.equivalents[col]:
-                newcol = self.selectable.corresponding_column(equiv, raiseerr=False, keys_ok=False)
+                newcol = self.selectable.corresponding_column(equiv, raiseerr=False, require_embedded=True, keys_ok=False)
                 if newcol:
                     return newcol
         return newcol
index ab29fdf07d2d5e81b334c658feb5d677d3e943cf..5012b5e4d15ef0b3a1cef8746eae611dbfc9c9b1 100644 (file)
@@ -499,6 +499,62 @@ class RelationTest4(testbase.ORMTest):
         assert a1 not in sess
         assert b1 not in sess
 
+class RelationTest5(testbase.ORMTest):
+    """test a map to a select that relates to a map to the table"""
+    def define_tables(self, metadata):
+        global items
+        items = Table('items', metadata,
+            Column('item_policy_num', String(10), primary_key=True, key='policyNum'),
+            Column('item_policy_eff_date', Date, primary_key=True, key='policyEffDate'),
+            Column('item_type', String(20), primary_key=True, key='type'),
+            Column('item_id', Integer, primary_key=True, key='id'),
+        )
+
+    def test_basic(self):
+        class Container(object):pass
+        class LineItem(object):pass
+        
+        container_select = select(
+            [items.c.policyNum, items.c.policyEffDate, items.c.type],
+            distinct=True,
+            ).alias('container_select')
+
+        mapper(LineItem, items)
+
+        mapper(Container, container_select, order_by=asc(container_select.c.type), properties=dict(
+            lineItems = relation(LineItem, lazy=True, cascade='all, delete-orphan', order_by=asc(items.c.type),
+                primaryjoin=and_(
+                    container_select.c.policyNum==items.c.policyNum,
+                    container_select.c.policyEffDate==items.c.policyEffDate,
+                    container_select.c.type==items.c.type
+                ),
+                foreign_keys=[
+                    items.c.policyNum,
+                    items.c.policyEffDate,
+                    items.c.type,
+                ],
+            )
+        ))
+        session = create_session()
+        con = Container()
+        con.policyNum = "99"
+        con.policyEffDate = datetime.date.today()
+        con.type = "TESTER"
+        session.save(con)
+        for i in range(0, 10):
+            li = LineItem()
+            li.id = i
+            con.lineItems.append(li)
+            session.save(li)
+        session.flush()
+        session.clear()
+        newcon = session.query(Container).selectfirst()
+        assert con.policyNum == newcon.policyNum
+        assert len(newcon.lineItems) == 10
+        for old, new in zip(con.lineItems, newcon.lineItems):
+            assert old.id == new.id
+        
+        
 class TypeMatchTest(testbase.ORMTest):
     """test errors raised when trying to add items whose type is not handled by a relation"""
     def define_tables(self, metadata):