]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- _Label adds itself to the proxy collection so that it works in correspoinding colum...
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 6 May 2008 00:55:49 +0000 (00:55 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 6 May 2008 00:55:49 +0000 (00:55 +0000)
- this partially fixes some issues in [ticket:1022] but leaving the "unlabeled" fix for 0.5 for now

lib/sqlalchemy/sql/expression.py
test/orm/eager_relations.py
test/sql/selectable.py

index 269d31661eb34bc0d76604e7fa77f73e1df68190..867fdd69c3ce5e6be7d2a14957c31b5a456edf19 100644 (file)
@@ -1733,7 +1733,10 @@ class FromClause(Selectable):
 
         col, intersect = None, None
         target_set = column.proxy_set
-        for c in self.c + [self.oid_column]:
+        cols = self.c
+        if self.oid_column:
+            cols += [self.oid_column]
+        for c in cols:
             i = c.proxy_set.intersection(target_set)
             if i and \
                 (not require_embedded or c.proxy_set.issuperset(target_set)) and \
@@ -2553,9 +2556,11 @@ class _Label(ColumnElement):
 
     def _make_proxy(self, selectable, name = None):
         if isinstance(self.obj, (Selectable, ColumnElement)):
-            return self.obj._make_proxy(selectable, name=self.name)
+            e = self.obj._make_proxy(selectable, name=self.name)
         else:
-            return column(self.name)._make_proxy(selectable=selectable)
+            e = column(self.name)._make_proxy(selectable=selectable)
+        e.proxies.append(self)
+        return e
 
 class _ColumnClause(ColumnElement):
     """Represents a generic column expression from any textual string.
index 866eec718e28abdc2ee35fa9accbb814fe5fbfc8..418df83dda7a8c4fca046ccf0984a8ab3218953c 100644 (file)
@@ -1047,14 +1047,16 @@ class SubqueryTest(ORMTest):
                     self.assertEquals(user.query_score, user.prop_score)
             self.assert_sql_count(testing.db, go, 1)
 
-            u = session.query(User).filter_by(name='joe').one()
-            self.assertEquals(u.query_score, u.prop_score)
-
-            # fails:
-            #def go():
-            #    u = session.query(User).filter_by(name='joe').one()
-            #    self.assertEquals(u.query_score, u.prop_score)
-            #self.assert_sql_count(testing.db, go, 1)
+
+            # fails for non labeled (fixed in 0.5):
+            if labeled:
+                def go():
+                    u = session.query(User).filter_by(name='joe').one()
+                    self.assertEquals(u.query_score, u.prop_score)
+                self.assert_sql_count(testing.db, go, 1)
+            else:
+                u = session.query(User).filter_by(name='joe').one()
+                self.assertEquals(u.query_score, u.prop_score)
             
             for t in (tags_table, users_table):
                 t.delete().execute()
index d3b639767477bf83c6a35117338f421d39bab9ea..b29ba8d5c0c30d21b7b7639fb6cd2d84f283ff68 100755 (executable)
@@ -174,6 +174,15 @@ class SelectableTest(TestBase, AssertsExecutionResults):
         print str(j)
         self.assert_(criterion.compare(j.onclause))
 
+    def test_labeled_select_correspoinding(self):
+        l1 = select([func.max(table.c.col1)]).label('foo')
+
+        s = select([l1])
+        assert s.corresponding_column(l1).name == s.c.foo
+
+        s = select([table.c.col1, l1])
+        assert s.corresponding_column(l1).name == s.c.foo
+
     def testselectaliaslabels(self):
         a = table2.select(use_labels=True).alias('a')
         print str(a.select())