]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
finally, a really straightforward reduce() method which reduces cols
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 15 Jan 2008 17:59:27 +0000 (17:59 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 15 Jan 2008 17:59:27 +0000 (17:59 +0000)
to the minimal set for every test case I can come up with, and
now replaces all the cruft in Mapper._compile_pks() as well as
Join.__init_primary_key().  mappers can now handle aliased selects
and figure out the correct PKs pretty well [ticket:933]

CHANGES
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/session.py
lib/sqlalchemy/sql/expression.py
lib/sqlalchemy/sql/util.py
test/orm/inheritance/query.py
test/orm/mapper.py
test/sql/selectable.py

diff --git a/CHANGES b/CHANGES
index a940248185790b23e5dd6940e7075c0c4d80b3eb..b555603b7b0357e9825a66cc40168e36155f59cb 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -30,7 +30,14 @@ CHANGES
     - general improvements to the behavior of join() in 
       conjunction with polymorphic mappers, i.e. joining
       from/to polymorphic mappers and properly applying 
-      aliases
+      aliases.  
+      
+    - fixed/improved behavior when a mapper determines the
+      natural "primary key" of a mapped join, it will more 
+      effectively reduce columns which are equivalent via
+      foreign key relation.  This affects how many arguments 
+      need to be sent to query.get(), among other things.
+      [ticket:933]
       
     - fixed bug in polymorphic inheritance which made it 
       difficult to set a working "order_by" on a polymorphic
index 61f5a65791215d27a169fa83353bd4809db30250..07075efd09077aa11166ed95eb7a5fd87105f84a 100644 (file)
@@ -418,6 +418,7 @@ class Mapper(object):
         all_cols = util.Set(chain(*[c2 for c2 in [col.proxy_set for col in [c for c in self._columntoproperty]]]))
         pk_cols = util.Set([c for c in all_cols if c.primary_key])
 
+        # identify primary key columns which are also mapped by this mapper.
         for t in util.Set(self.tables + [self.mapped_table]):
             self._all_tables.add(t)
             if t.primary_key and pk_cols.issuperset(t.primary_key):
@@ -425,6 +426,7 @@ class Mapper(object):
                 self._pks_by_table[t] = util.OrderedSet(t.primary_key).intersection(pk_cols)
             self._cols_by_table[t] = util.OrderedSet(t.c).intersection(all_cols)
 
+        # if explicit PK argument sent, add those columns to the primary key mappings
         if self.primary_key_argument:
             for k in self.primary_key_argument:
                 if k.table not in self._pks_by_table:
@@ -432,58 +434,22 @@ class Mapper(object):
                 self._pks_by_table[k.table].add(k)
 
         if self.mapped_table not in self._pks_by_table or len(self._pks_by_table[self.mapped_table]) == 0:
-            raise exceptions.ArgumentError("Could not assemble any primary key columns for mapped table '%s'" % (self.mapped_table.name))
+            raise exceptions.ArgumentError("Mapper %s could not assemble any primary key columns for mapped table '%s'" % (self, self.mapped_table.description))
 
         if self.inherits is not None and not self.concrete and not self.primary_key_argument:
+            # if inheriting, the "primary key" for this mapper is that of the inheriting (unless concrete or explicit)
             self.primary_key = self.inherits.primary_key
             self._get_clause = self.inherits._get_clause
         else:
-            # create the "primary_key" for this mapper.  this will flatten "equivalent" primary key columns
-            # into one column, where "equivalent" means that one column references the other via foreign key, or
-            # multiple columns that all reference a common parent column.  it will also resolve the column
-            # against the "mapped_table" of this mapper.
-
-            # TODO !!!
-            #primary_key = sqlutil.reduce_columns((self.primary_key_argument or self._pks_by_table[self.mapped_table]))
-
-            # TODO !!! remove all this
-            primary_key = expression.ColumnSet()
-
-            for col in (self.primary_key_argument or self._pks_by_table[self.mapped_table]):
-                c = self.mapped_table.corresponding_column(col)
-                if c is None:
-                    for cc in self._equivalent_columns[col]:
-                        c = self.mapped_table.corresponding_column(cc)
-                        if c is not None:
-                            break
-                    else:
-                        raise exceptions.ArgumentError("Cant resolve column " + str(col))
-
-                # this step attempts to resolve the column to an equivalent which is not
-                # a foreign key elsewhere.  this helps with joined table inheritance
-                # so that PKs are expressed in terms of the base table which is always
-                # present in the initial select
-                # TODO: this is a little hacky right now, the "tried" list is to prevent
-                # endless loops between cyclical FKs, try to make this cleaner/work better/etc.,
-                # perhaps via topological sort (pick the leftmost item)
-                tried = util.Set()
-                while True:
-                    if not len(c.foreign_keys) or c in tried:
-                        break
-                    for cc in c.foreign_keys:
-                        cc = cc.column
-                        c2 = self.mapped_table.corresponding_column(cc)
-                        if c2 is not None:
-                            c = c2
-                            tried.add(c)
-                            break
-                    else:
-                        break
-                primary_key.add(c)
+            # determine primary key from argument or mapped_table pks - reduce to the minimal set of columns
+            if self.primary_key_argument:
+                primary_key = sqlutil.reduce_columns([self.mapped_table.corresponding_column(c) for c in self.primary_key_argument])
+            else:
+                primary_key = sqlutil.reduce_columns(self._pks_by_table[self.mapped_table])
 
             if len(primary_key) == 0:
-                raise exceptions.ArgumentError("Could not assemble any primary key columns for mapped table '%s'" % (self.mapped_table.name))
-
+                raise exceptions.ArgumentError("Mapper %s could not assemble any primary key columns for mapped table '%s'" % (self, self.mapped_table.description))
+            
             self.primary_key = primary_key
             self.__log("Identified primary key columns: " + str(primary_key))
 
@@ -730,15 +696,9 @@ class Mapper(object):
         if self.select_table is not self.mapped_table:
             # turn a straight join into an aliased selectable
             if isinstance(self.select_table, sql.Join):
-                if self.primary_key_argument:
-                    primary_key_arg = self.primary_key_argument
-                else:
-                    primary_key_arg = self.select_table.primary_key
                 self.select_table = self.select_table.select(use_labels=True).alias()
-            else:
-                primary_key_arg = self.primary_key_argument
 
-            self.__surrogate_mapper = Mapper(self.class_, self.select_table, non_primary=True, _polymorphic_map=self.polymorphic_map, polymorphic_on=_corresponding_column_or_error(self.select_table, self.polymorphic_on), primary_key=primary_key_arg)
+            self.__surrogate_mapper = Mapper(self.class_, self.select_table, non_primary=True, _polymorphic_map=self.polymorphic_map, polymorphic_on=_corresponding_column_or_error(self.select_table, self.polymorphic_on), primary_key=self.primary_key_argument)
             adapter = sqlutil.ClauseAdapter(self.select_table, equivalents=self.__surrogate_mapper._equivalent_columns)
             
             if self.order_by:
index b817d29bc8d54fd95919f4f907a3b1933b9ff28a..0d58a52161b59cb12fced954117186859ac44329 100644 (file)
@@ -572,7 +572,7 @@ class Session(object):
         This is equivalent to calling ``expunge()`` for all objects in
         this ``Session``.
         """
-
+        
         for instance in self:
             self._unattach(instance)
         self.uow = unitofwork.UnitOfWork(self)
index be870ee792701bb5e59b99707e2962c861c6250a..3ebc4960fac9349c47ced5ac109d8baa9330babe 100644 (file)
@@ -31,7 +31,7 @@ from sqlalchemy import util, exceptions
 from sqlalchemy.sql import operators, visitors
 from sqlalchemy import types as sqltypes
 
-functions, schema = None, None
+functions, schema, sql_util = None, None, None
 DefaultDialect, ClauseAdapter = None, None
 
 __all__ = [
@@ -2179,51 +2179,14 @@ class Join(FromClause):
 
         columns = list(self._flatten_exportable_columns())
 
-        #global sql_util
-        #if not sql_util:
-        #    from sqlalchemy.sql import util as sql_util
-        #self._primary_key = sql_util.reduce_columns([c for c in columns if c.primary_key], self.onclause)
-
-        self.__init_primary_key(columns)
+        global sql_util
+        if not sql_util:
+            from sqlalchemy.sql import util as sql_util
+        self._primary_key = sql_util.reduce_columns([c for c in columns if c.primary_key], self.onclause)
 
         for co in columns:
             cp = self._proxy_column(co)
 
-    def __init_primary_key(self, columns):
-        # TODO !!! remove all this
-        global schema
-        if schema is None:
-            from sqlalchemy import schema
-        pkcol = util.Set([c for c in columns if c.primary_key])
-
-        equivs = {}
-        def add_equiv(a, b):
-            for x, y in ((a, b), (b, a)):
-                if x in equivs:
-                    equivs[x].add(y)
-                else:
-                    equivs[x] = util.Set([y])
-
-        def visit_binary(binary):
-            if binary.operator == operators.eq and isinstance(binary.left, schema.Column) and isinstance(binary.right, schema.Column):
-                add_equiv(binary.left, binary.right)
-        visitors.traverse(self.onclause, visit_binary=visit_binary)
-
-        for col in pkcol:
-            for fk in col.foreign_keys:
-                if fk.column in pkcol:
-                    add_equiv(col, fk.column)
-
-        omit = util.Set()
-        for col in pkcol:
-            p = col
-            for c in equivs.get(col, util.Set()):
-                if p.references(c) or (c.primary_key and not p.primary_key):
-                    omit.add(p)
-                    p = c
-
-        self._primary_key = ColumnSet(pkcol.difference(omit))
-
     def description(self):
         return "Join object on %s(%d) and %s(%d)" % (self.left.description, id(self.left), self.right.description, id(self.right))
     description = property(description)
@@ -2284,6 +2247,12 @@ class Join(FromClause):
         """Returns the column list of this Join with all equivalently-named,
         equated columns folded into one column, where 'equated' means they are
         equated to each other in the ON clause of this join.
+        
+        this method is used by select(fold_equivalents=True).
+        
+        The primary usage for this is when generating UNIONs so that 
+        each selectable can have distinctly-named columns without the need
+        for use_labels=True.
         """
 
         if self.__folded_equivalents is not None:
index 0989cb43e9d9a358f565b3ee956f46917a7c56e6..93998c9a91ac0c70d6c9de4469a5fd7c19ce872f 100644 (file)
@@ -52,30 +52,42 @@ def find_columns(clause):
     
     
 def reduce_columns(columns, *clauses):
-    raise NotImplementedError()
+    """given a list of columns, return a 'reduced' set based on natural equivalents.
+
+    the set is reduced to the smallest list of columns which have no natural
+    equivalent present in the list.  A "natural equivalent" means that two columns
+    will ultimately represent the same value because they are related by a foreign key.
+    
+    \*clauses is an optional list of join clauses which will be traversed
+    to further identify columns that are "equivalent".
     
-    # TODO !!!
-    all_proxied_cols = util.Set(chain(*[c.proxy_set for c in columns]))
+    This function is primarily used to determine the most minimal "primary key"
+    from a selectable, by reducing the set of primary key columns present
+    in the the selectable to just those that are not repeated.
+    
+    """
     
     columns = util.Set(columns)
     
-    equivs = {}
+    omit = util.Set()
     for col in columns:
         for fk in col.foreign_keys:
-            if fk.column in all_proxied_cols:
-                for c in columns:
-                    if col.references(c):
-                        equivs[col] = c
+            for c in columns:
+                if c is col:
+                    continue
+                if fk.column.shares_lineage(c):
+                    omit.add(col)
+                    break
     
     if clauses:
         def visit_binary(binary):
-            if binary.operator == operators.eq and binary.left in columns and binary.right in columns:
-                equivs[binary.left] = binary.right
+            cols = columns.difference(omit)
+            if binary.operator == operators.eq and binary.left in cols and binary.right in cols:
+                omit.add(binary.right)
         for clause in clauses:
             visitors.traverse(clause, visit_binary=visit_binary)
     
-    result = util.Set([c for c in columns if c not in equivs])
-    return expression.ColumnSet(result)
+    return expression.ColumnSet(columns.difference(omit))
 
 class ColumnsInClause(visitors.ClauseVisitor):
     """Given a selectable, visit clauses and determine if any columns
index b3239d3b3ae472a04823dfeae0d998ee31aedc7a..b9f11faa7c31fd351f7b6fea3f3cf38eb75935c3 100644 (file)
@@ -159,7 +159,16 @@ def make_test(select_type):
             all_employees = [e1, e2, b1, m1, e3]
             c1_employees = [e1, e2, b1, m1]
             c2_employees = [e3]
-
+            
+        def test_get(self):
+            sess = create_session()
+            
+            # for all mappers, ensure the primary key has been calculated as just the "person_id"
+            # column
+            self.assertEquals(sess.query(Person).get(e1.person_id), Engineer(name="dilbert"))
+            self.assertEquals(sess.query(Engineer).get(e1.person_id), Engineer(name="dilbert"))
+            self.assertEquals(sess.query(Manager).get(b1.person_id), Boss(name="pointy haired boss"))
+            
         def test_filter_on_subclass(self):
             sess = create_session()
             self.assertEquals(sess.query(Engineer).all()[0], Engineer(name="dilbert"))
index 70cd81428d3dd7db4198677eb251b2802afecc41..a8f75a31b808a518b10b0f16d380086ce99a301d 100644 (file)
@@ -72,7 +72,7 @@ class MapperTest(MapperSuperTest):
             mapper(User, s)
             assert False
         except exceptions.ArgumentError, e:
-            assert str(e) == "Could not assemble any primary key columns for mapped table 'foo'"
+            assert "could not assemble any primary key columns for mapped table 'foo'" in str(e)
 
     def test_compileonsession(self):
         m = mapper(User, users)
index a64697b81d0e0353f9411111b1ef208dfed3814e..45bd7d823a2bdfdab408d7e258753d4da3b0e00e 100755 (executable)
@@ -5,6 +5,7 @@ every selectable unit behaving nicely with others.."""
 import testenv; testenv.configure_for_tests()
 from sqlalchemy import *
 from testlib import *
+from sqlalchemy.sql import util as sql_util
 
 metadata = MetaData()
 table = Table('table1', metadata,
@@ -275,6 +276,124 @@ class PrimaryKeyTest(AssertMixin):
         assert str(j) == "a JOIN b ON a.id = b.id AND b.x = :b_x_1", str(j)
         assert list(j.primary_key) == [a.c.id, b.c.x]
 
+    def test_onclause_direction(self):
+        metadata = MetaData()
+
+        employee = Table( 'Employee', metadata,
+            Column('name', String(100)),
+            Column('id', Integer, primary_key= True),
+        )
+
+        engineer = Table( 'Engineer', metadata,
+            Column('id', Integer, ForeignKey( 'Employee.id', ), primary_key=True),
+        )
+
+        self.assertEquals(
+            set(employee.join(engineer, employee.c.id==engineer.c.id).primary_key),
+            set([employee.c.id])
+        )
+
+        self.assertEquals(
+            set(employee.join(engineer, engineer.c.id==employee.c.id).primary_key),
+            set([employee.c.id])
+        )
+
+
+class ReduceTest(AssertMixin):
+    def test_reduce(self):
+        meta = MetaData()
+        t1 = Table('t1', meta,
+            Column('t1id', Integer, primary_key=True),
+            Column('t1data', String(30)))
+        t2 = Table('t2', meta,
+            Column('t2id', Integer, ForeignKey('t1.t1id'), primary_key=True),
+            Column('t2data', String(30)))
+        t3 = Table('t3', meta,
+            Column('t3id', Integer, ForeignKey('t2.t2id'), primary_key=True),
+            Column('t3data', String(30)))
+        
+        
+        self.assertEquals(
+            set(sql_util.reduce_columns([t1.c.t1id, t1.c.t1data, t2.c.t2id, t2.c.t2data, t3.c.t3id, t3.c.t3data])),
+            set([t1.c.t1id, t1.c.t1data, t2.c.t2data, t3.c.t3data])
+        )
+    
+    def test_reduce_aliased_join(self):
+        metadata = MetaData()
+        people = Table('people', metadata,
+           Column('person_id', Integer, Sequence('person_id_seq', optional=True), primary_key=True),
+           Column('name', String(50)),
+           Column('type', String(30)))
+
+        engineers = Table('engineers', metadata,
+           Column('person_id', Integer, ForeignKey('people.person_id'), primary_key=True),
+           Column('status', String(30)),
+           Column('engineer_name', String(50)),
+           Column('primary_language', String(50)),
+          )
+     
+        managers = Table('managers', metadata,
+           Column('person_id', Integer, ForeignKey('people.person_id'), primary_key=True),
+           Column('status', String(30)),
+           Column('manager_name', String(50))
+           )
+        
+        pjoin = people.outerjoin(engineers).outerjoin(managers).select(use_labels=True).alias('pjoin')
+        self.assertEquals(
+            set(sql_util.reduce_columns([pjoin.c.people_person_id, pjoin.c.engineers_person_id, pjoin.c.managers_person_id])),
+            set([pjoin.c.people_person_id])
+        )
+        
+    def test_reduce_aliased_union(self):
+        metadata = MetaData()
+        item_table = Table(
+            'item', metadata,
+            Column('id', Integer, ForeignKey('base_item.id'), primary_key=True),
+            Column('dummy', Integer, default=0))
+
+        base_item_table = Table(
+            'base_item', metadata,
+            Column('id', Integer, primary_key=True),
+            Column('child_name', String(255), default=None))
+        
+        from sqlalchemy.orm.util import polymorphic_union
+        
+        item_join = polymorphic_union( {
+            'BaseItem':base_item_table.select(base_item_table.c.child_name=='BaseItem'),
+            'Item':base_item_table.join(item_table),
+            }, None, 'item_join')
+            
+        self.assertEquals(
+            set(sql_util.reduce_columns([item_join.c.id, item_join.c.dummy, item_join.c.child_name])),
+            set([item_join.c.id, item_join.c.dummy, item_join.c.child_name])
+        )    
+    
+    def test_reduce_aliased_union_2(self):
+        metadata = MetaData()
+
+        page_table = Table('page', metadata,
+            Column('id', Integer, primary_key=True),
+        )
+        magazine_page_table = Table('magazine_page', metadata,
+            Column('page_id', Integer, ForeignKey('page.id'), primary_key=True),
+        )
+        classified_page_table = Table('classified_page', metadata,
+            Column('magazine_page_id', Integer, ForeignKey('magazine_page.page_id'), primary_key=True),
+        )
+        
+        from sqlalchemy.orm.util import polymorphic_union
+        pjoin = polymorphic_union(
+            {
+                'm': page_table.join(magazine_page_table),
+                'c': page_table.join(magazine_page_table).join(classified_page_table),
+            }, None, 'page_join')
+            
+        self.assertEquals(
+            set(sql_util.reduce_columns([pjoin.c.id, pjoin.c.page_id, pjoin.c.magazine_page_id])),
+            set([pjoin.c.id])
+        )    
+    
+            
 class DerivedTest(AssertMixin):
     def test_table(self):
         meta = MetaData()