]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- select_table mapper turns straight join into aliased select + custom PK, to allow
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 15 Jan 2008 02:34:17 +0000 (02:34 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 15 Jan 2008 02:34:17 +0000 (02:34 +0000)
joins onto select_table mappers
- starting a generalized reduce_columns func

lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/sql/expression.py
lib/sqlalchemy/sql/util.py
test/orm/inheritance/query.py

index c733c68ad2a7b131847305087d3ed8fb0dd919a4..61f5a65791215d27a169fa83353bd4809db30250 100644 (file)
@@ -443,6 +443,10 @@ class Mapper(object):
             # multiple columns that all reference a common parent column.  it will also resolve the column
             # against the "mapped_table" of this mapper.
 
+            # TODO !!!
+            #primary_key = sqlutil.reduce_columns((self.primary_key_argument or self._pks_by_table[self.mapped_table]))
+
+            # TODO !!! remove all this
             primary_key = expression.ColumnSet()
 
             for col in (self.primary_key_argument or self._pks_by_table[self.mapped_table]):
@@ -724,7 +728,17 @@ class Mapper(object):
         """
 
         if self.select_table is not self.mapped_table:
-            self.__surrogate_mapper = Mapper(self.class_, self.select_table, non_primary=True, _polymorphic_map=self.polymorphic_map, polymorphic_on=_corresponding_column_or_error(self.select_table, self.polymorphic_on), primary_key=self.primary_key_argument)
+            # turn a straight join into an aliased selectable
+            if isinstance(self.select_table, sql.Join):
+                if self.primary_key_argument:
+                    primary_key_arg = self.primary_key_argument
+                else:
+                    primary_key_arg = self.select_table.primary_key
+                self.select_table = self.select_table.select(use_labels=True).alias()
+            else:
+                primary_key_arg = self.primary_key_argument
+
+            self.__surrogate_mapper = Mapper(self.class_, self.select_table, non_primary=True, _polymorphic_map=self.polymorphic_map, polymorphic_on=_corresponding_column_or_error(self.select_table, self.polymorphic_on), primary_key=primary_key_arg)
             adapter = sqlutil.ClauseAdapter(self.select_table, equivalents=self.__surrogate_mapper._equivalent_columns)
             
             if self.order_by:
index 4eb555d4d4fc5ae5df36c8999cac3754da3fa506..be870ee792701bb5e59b99707e2962c861c6250a 100644 (file)
@@ -2178,11 +2178,19 @@ class Join(FromClause):
         self._foreign_keys = util.Set()
 
         columns = list(self._flatten_exportable_columns())
+
+        #global sql_util
+        #if not sql_util:
+        #    from sqlalchemy.sql import util as sql_util
+        #self._primary_key = sql_util.reduce_columns([c for c in columns if c.primary_key], self.onclause)
+
         self.__init_primary_key(columns)
+
         for co in columns:
             cp = self._proxy_column(co)
 
     def __init_primary_key(self, columns):
+        # TODO !!! remove all this
         global schema
         if schema is None:
             from sqlalchemy import schema
index c2ac26557ee06c96dcdcc124323f63a07194b9c8..0989cb43e9d9a358f565b3ee956f46917a7c56e6 100644 (file)
@@ -1,5 +1,6 @@
 from sqlalchemy import util, schema, topological
-from sqlalchemy.sql import expression, visitors
+from sqlalchemy.sql import expression, visitors, operators
+from itertools import chain
 
 """Utility functions that build upon SQL and Schema constructs."""
 
@@ -49,6 +50,33 @@ def find_columns(clause):
     visitors.traverse(clause, visit_column=visit_column)
     return cols
     
+    
+def reduce_columns(columns, *clauses):
+    raise NotImplementedError()
+    
+    # TODO !!!
+    all_proxied_cols = util.Set(chain(*[c.proxy_set for c in columns]))
+    
+    columns = util.Set(columns)
+    
+    equivs = {}
+    for col in columns:
+        for fk in col.foreign_keys:
+            if fk.column in all_proxied_cols:
+                for c in columns:
+                    if col.references(c):
+                        equivs[col] = c
+    
+    if clauses:
+        def visit_binary(binary):
+            if binary.operator == operators.eq and binary.left in columns and binary.right in columns:
+                equivs[binary.left] = binary.right
+        for clause in clauses:
+            visitors.traverse(clause, visit_binary=visit_binary)
+    
+    result = util.Set([c for c in columns if c not in equivs])
+    return expression.ColumnSet(result)
+
 class ColumnsInClause(visitors.ClauseVisitor):
     """Given a selectable, visit clauses and determine if any columns
     from the clause are in the selectable.
index 2e3d392556b7d159a72a5eef864c657f2589bd88..b3239d3b3ae472a04823dfeae0d998ee31aedc7a 100644 (file)
@@ -187,16 +187,13 @@ def make_test(select_type):
                 self.assertEquals(sess.query(Person).join('paperwork', aliased=aliased).filter(Person.c.name.like('%dog%')).filter(Paperwork.description.like('%#2%')).all(), [m1])
     
         def test_join_to_polymorphic(self):
-            if select_type == 'Joins':
-                return
-                
             sess = create_session()
             self.assertEquals(sess.query(Company).join('employees').filter(Person.name=='vlad').one(), c2)
 
             self.assertEquals(sess.query(Company).join('employees', aliased=True).filter(Person.name=='vlad').one(), c2)
     
         def test_join_to_subclass(self):
-            if select_type in ('Joins', ''):
+            if select_type == '':
                 return
 
             sess = create_session()
@@ -214,8 +211,6 @@ def make_test(select_type):
             self.assertEquals(sess.query(Company).join(['employees', Engineer.machines]).filter(Machine.name.ilike("%thinkpad%")).all(), [c1])
         
         def test_join_through_polymorphic(self):
-            if select_type == 'Joins':
-                return
 
             sess = create_session()