]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
more "column targeting" enhancements..columns have a "depth" from their ultimate...
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 10 Jul 2007 06:51:58 +0000 (06:51 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 10 Jul 2007 06:51:58 +0000 (06:51 +0000)
CHANGES
lib/sqlalchemy/orm/util.py
lib/sqlalchemy/schema.py
lib/sqlalchemy/sql.py
test/orm/eagertest3.py
test/orm/mapper.py
test/sql/selectable.py

diff --git a/CHANGES b/CHANGES
index 606d9baba659713ee38c063b7ddf16ad8482cbfc..b281a8fb6c813e7e7d6154017735c49bcb55a760 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -39,6 +39,9 @@
     - DynamicMetaData has been renamed to ThreadLocalMetaData.  the
       DynamicMetaData name is deprecated and is an alias for ThreadLocalMetaData
       or a regular MetaData if threadlocal=False
+    - some enhancements to "column targeting", the ability to match a column
+      to a "corresponding" column in another selectable.  this affects mostly
+      ORM ability to map to complex joins
     - MetaData and all SchemaItems are safe to use with pickle.  slow
       table reflections can be dumped into a pickled file to be reused later.
       Just reconnect the engine to the metadata after unpickling. [ticket:619]
index 923dd679770ed884df4b513ae87bc490de1aa76c..3b3b9b7ed57408838039ef58f6fb2193bca4b671 100644 (file)
@@ -89,8 +89,8 @@ class TranslatingDict(dict):
 
     def __translate_col(self, col):
         ourcol = self.selectable.corresponding_column(col, keys_ok=False, raiseerr=False)
-        #if col is not ourcol:
-        #    print "TD TRANSLATING ", col, "TO", ourcol
+#        if col is not ourcol and ourcol is not None:
+#            print "TD TRANSLATING ", col, "TO", ourcol
         if ourcol is None:
             return col
         else:
index eb7eb8c1d9ded517bf55667d38c9b446fcdf10e9..d7d728f2b62d6f9e3c5e75c8bd473a743f784c34 100644 (file)
@@ -600,11 +600,11 @@ class Column(SchemaItem, sql._ColumnClause):
         This is a copy of this ``Column`` referenced by a different parent
         (such as an alias or select statement).
         """
-
         fk = [ForeignKey(f._colspec) for f in self.foreign_keys]
         c = Column(name or self.name, self.type, self.default, key = name or self.key, primary_key = self.primary_key, nullable = self.nullable, _is_oid = self._is_oid, quote=self.quote, *fk)
         c.table = selectable
         c.orig_set = self.orig_set
+        c._source_column = self
         c.__originating_column = self.__originating_column
         if not c._is_oid:
             selectable.columns.add(c)
index c86fc561a6b937fb5ceee4411bad5b0b33640cbf..2a22a40c131751540650072dbebf6a84466f4c52 100644 (file)
@@ -1542,7 +1542,17 @@ class ColumnElement(Selectable, _CompareMixin):
                 return True
         else:
             return False
-
+    
+    def _distance(self, othercolumn):
+        c = othercolumn
+        count = 0
+        while c is not self:
+            c = c._source_column
+            if c is None:
+                return -1
+            count += 1
+        return count
+        
     def _make_proxy(self, selectable, name=None):
         """Create a new ``ColumnElement`` representing this
         ``ColumnElement`` as it appears in the select list of a
@@ -1695,7 +1705,7 @@ class FromClause(Selectable):
         """
         if column in self.c:
             return column
-            
+        
         if require_embedded and column not in util.Set(self._get_all_embedded_columns()):
             if not raiseerr:
                 return None
@@ -1757,9 +1767,9 @@ class FromClause(Selectable):
         for co in self._adjusted_exportable_columns():
             cp = self._proxy_column(co)
             for ci in cp.orig_set:
-                # note that some ambiguity is raised here, whereby a selectable might have more than 
-                # one column that maps to an "original" column.  examples include unions and joins
-                self._orig_cols[ci] = cp
+                cx = self._orig_cols.get(ci)
+                if cx is None or ci._distance(cp) < ci._distance(cx):
+                    self._orig_cols[ci] = cp
         if self.oid_column is not None:
             for ci in self.oid_column.orig_set:
                 self._orig_cols[ci] = self.oid_column
@@ -2078,7 +2088,8 @@ class _Cast(ColumnElement):
         self.type = sqltypes.to_instance(totype)
         self.clause = clause
         self.typeclause = _TypeClause(self.type)
-
+        self._source_column = None
+        
     def get_children(self, **kwargs):
         return self.clause, self.typeclause
     def accept_visitor(self, visitor):
@@ -2090,6 +2101,7 @@ class _Cast(ColumnElement):
     def _make_proxy(self, selectable, name=None):
         if name is not None:
             co = _ColumnClause(name, selectable, type=self.type)
+            co._source_column = self
             co.orig_set = self.orig_set
             selectable.columns[name]= co
             return co
@@ -2512,6 +2524,7 @@ class _ColumnClause(ColumnElement):
         self.table = selectable
         self.type = sqltypes.to_instance(type)
         self._is_oid = _is_oid
+        self._source_column = None
         self.__label = None
         self.case_sensitive = case_sensitive
         self.is_literal = is_literal
@@ -2571,6 +2584,7 @@ class _ColumnClause(ColumnElement):
         is_literal = self.is_literal and (name is None or name == self.name)
         c = _ColumnClause(name or self.name, selectable=selectable, _is_oid=self._is_oid, type=self.type, is_literal=is_literal)
         c.orig_set = self.orig_set
+        c._source_column = self
         if not self._is_oid:
             selectable.columns[c.name] = c
         return c
index a731581d55ad70a53c6ea082688579ba8a599165..8e77358124dcceaa6f36c03fb31f1cc085427c32 100644 (file)
@@ -415,7 +415,101 @@ class EagerTest5(testbase.ORMTest):
         # object is not in the session; therefore the lazy load cant trigger here,
         # eager load had to succeed
         assert len([c for c in d2.comments]) == 1
+
+class EagerTest6(testbase.ORMTest):
+    def define_tables(self, metadata):
+        global project_t, task_t, task_status_t, task_type_t, message_t, message_type_t
         
-    
+        project_t = Table('prj', metadata,
+                          Column('id',            Integer,      primary_key=True),
+                          Column('created',       DateTime ,    ),
+                          Column('title',         Unicode(100)),
+                          )
+
+        task_t = Table('task', metadata,
+                          Column('id',            Integer,      primary_key=True),
+                          Column('status_id',     Integer,      ForeignKey('task_status.id'), nullable=False),
+                          Column('title',         Unicode(100)),
+                          Column('task_type_id',  Integer ,     ForeignKey('task_type.id'), nullable=False),
+                          Column('prj_id',        Integer ,     ForeignKey('prj.id'), nullable=False),
+                          )
+
+        task_status_t = Table('task_status', metadata,
+                                Column('id',                Integer,      primary_key=True),
+                                )
+
+        task_type_t = Table('task_type', metadata,
+                            Column('id',   Integer,    primary_key=True),
+                            )
+
+        message_t  = Table('msg', metadata,
+                            Column('id', Integer,  primary_key=True),
+                            Column('posted',    DateTime, index=True,),
+                            Column('type_id',   Integer, ForeignKey('msg_type.id')),
+                            Column('task_id',   Integer, ForeignKey('task.id')),
+                            )
+
+        message_type_t = Table('msg_type', metadata,
+                                Column('id',                Integer,      primary_key=True),
+                                Column('name',              Unicode(20)),
+                                Column('display_name',      Unicode(20)),
+                                )
+
+    def setUp(self):
+        testbase.db.execute("INSERT INTO prj (title) values('project 1');")
+        testbase.db.execute("INSERT INTO task_status (id) values(1);")
+        testbase.db.execute("INSERT INTO task_type(id) values(1);")
+        testbase.db.execute("INSERT INTO task (title, task_type_id, status_id, prj_id) values('task 1',1,1,1);")
+            
+    def test_nested_joins(self):
+        # this is testing some subtle column resolution stuff,
+        # concerning corresponding_column() being extremely accurate
+        # as well as how mapper sets up its column properties
+        
+        class Task(object):pass
+        class Task_Type(object):pass
+        class Message(object):pass
+        class Message_Type(object):pass
+
+        tsk_cnt_join = outerjoin(project_t, task_t, task_t.c.prj_id==project_t.c.id)
+
+        ss = select([project_t.c.id.label('prj_id'), func.count(task_t.c.id).label('tasks_number')], 
+                    from_obj=[tsk_cnt_join], group_by=[project_t.c.id]).alias('prj_tsk_cnt_s')
+        j = join(project_t, ss, project_t.c.id == ss.c.prj_id)
+
+        mapper(Task_Type, task_type_t)
+
+        mapper( Task, task_t,
+                              properties=dict(type=relation(Task_Type, lazy=False),
+                                             ))
+
+        mapper(Message_Type, message_type_t)
+
+        mapper(Message, message_t, 
+                         properties=dict(type=relation(Message_Type, lazy=False, uselist=False),
+                                         ))
+
+        tsk_cnt_join = outerjoin(project_t, task_t, task_t.c.prj_id==project_t.c.id)
+        ss = select([project_t.c.id.label('prj_id'), func.count(task_t.c.id).label('tasks_number')], 
+                    from_obj=[tsk_cnt_join], group_by=[project_t.c.id]).alias('prj_tsk_cnt_s')
+        j = join(project_t, ss, project_t.c.id == ss.c.prj_id)
+
+        j  = outerjoin( task_t, message_t, task_t.c.id==message_t.c.task_id)
+        jj = select([ task_t.c.id.label('task_id'),
+                      func.count(message_t.c.id).label('props_cnt')],
+                      from_obj=[j], group_by=[task_t.c.id]).alias('prop_c_s')
+        jjj = join(task_t, jj, task_t.c.id == jj.c.task_id)
+
+        class cls(object):pass
+
+        props =dict(type=relation(Task_Type, lazy=False))
+        print [c.key for c in jjj.c]
+        cls.mapper = mapper( cls, jjj, properties=props)
+
+        session = create_session()
+
+        for t in session.query(cls.mapper).limit(10).offset(0).list():
+            print t.id, t.title, t.props_cnt        
+
 if __name__ == "__main__":    
     testbase.main()
index 63af53b967ed0799c9fbc25d95b42dc5fe121938..197f988b7aeff89df8e2cdd221cb0b72f31da928 100644 (file)
@@ -497,7 +497,7 @@ class MapperTest(MapperSuperTest):
             class_mapper(User)
         except exceptions.ArgumentError, e:
             assert str(e) == "Column '%s' is not represented in mapper's table.  Use the `column_property()` function to force this column to be mapped as a read-only attribute." % str(f)
-            clear_mappers()
+        clear_mappers()
         
         mapper(User, users, properties={
             'concat': column_property(f),
index 57ad998860d6441c2a1f3490b28232c2303be2af..340c55837d5b0d51d38b8b72dd9da878dac047ac 100755 (executable)
@@ -27,17 +27,42 @@ table2 = Table('table2', db,
 )\r
 \r
 class SelectableTest(testbase.AssertMixin):\r
+    def testdistance(self):\r
+        s = select([table.c.col1.label('c2'), table.c.col1, table.c.col1.label('c1')])\r
+\r
+        # didnt do this yet...col.label().make_proxy() has same "distance" as col.make_proxy() so far\r
+        #assert s.corresponding_column(table.c.col1) is s.c.col1\r
+        assert s.corresponding_column(s.c.col1) is s.c.col1\r
+        assert s.corresponding_column(s.c.c1) is s.c.c1\r
+        \r
     def testjoinagainstself(self):\r
         jj = select([table.c.col1.label('bar_col1')])\r
         jjj = join(table, jj, table.c.col1==jj.c.bar_col1)\r
+        \r
+        # test column directly agaisnt itself\r
         assert jjj.corresponding_column(jjj.c.table1_col1) is jjj.c.table1_col1\r
 \r
+        assert jjj.corresponding_column(jj.c.bar_col1) is jjj.c.bar_col1\r
+        \r
+        # test alias of the join, targets the column with the least \r
+        # "distance" between the requested column and the returned column\r
+        # (i.e. there is less indirection between j2.c.table1_col1 and table.c.col1, than\r
+        # there is from j2.c.bar_col1 to table.c.col1)\r
+        j2 = jjj.alias('foo')\r
+        assert j2.corresponding_column(table.c.col1) is j2.c.table1_col1\r
+        \r
+\r
     def testjoinagainstjoin(self):\r
         j  = outerjoin(table, table2, table.c.col1==table2.c.col2)\r
         jj = select([ table.c.col1.label('bar_col1')],from_obj=[j]).alias('foo')\r
         jjj = join(table, jj, table.c.col1==jj.c.bar_col1)\r
         assert jjj.corresponding_column(jjj.c.table1_col1) is jjj.c.table1_col1\r
+\r
+        j2 = jjj.alias('foo')\r
+        print j2.corresponding_column(jjj.c.table1_col1)\r
+        assert j2.corresponding_column(jjj.c.table1_col1) is j2.c.table1_col1\r
         \r
+        assert jjj.corresponding_column(jj.c.bar_col1) is jj.c.bar_col1\r
         \r
     def testtablealias(self):\r
         a = table.alias('a')\r
@@ -110,8 +135,8 @@ class SelectableTest(testbase.AssertMixin):
         j = join(a, table2)\r
         \r
         criterion = a.c.col1 == table2.c.col2\r
-        print\r
-        print str(j)\r
+        print criterion\r
+        print j.onclause\r
         self.assert_(criterion.compare(j.onclause))\r
 \r
     def testselectlabels(self):\r