]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
fixed glitch in Select visit traversal, fixes #693
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 27 Jul 2007 16:46:11 +0000 (16:46 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 27 Jul 2007 16:46:11 +0000 (16:46 +0000)
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/sql.py
test/orm/alltests.py
test/orm/selectable.py [new file with mode: 0644]
test/sql/generative.py

index 76cc412890e46d9b008f0222997a2d3fd27ea21b..92b186012ac3b41bbfd0c7bec219a722cc46d4b2 100644 (file)
@@ -407,6 +407,9 @@ class Mapper(object):
         # may be a join or other construct
         self.tables = sqlutil.TableFinder(self.mapped_table)
 
+        if not len(self.tables):
+            raise exceptions.InvalidRequestError("Could not find any Table objects in mapped table '%s'" % str(self.mapped_table))
+
         # determine primary key columns
         self.pks_by_table = {}
 
index 01588e92da1a98f32347fa2bcdff00a44f7974fd..ff92f0b430044f0e8544edfa8ddeda1f18a429ea 100644 (file)
@@ -3146,7 +3146,7 @@ class Select(_SelectBaseMixin, FromClause):
 
     def get_children(self, column_collections=True, **kwargs):
         return (column_collections and list(self.columns) or []) + \
-            list(self._froms) + \
+            list(self.locate_all_froms()) + \
             [x for x in (self._whereclause, self._having, self._order_by_clause, self._group_by_clause) if x is not None]
 
     def _recorrelate_froms(self, froms):
index 4f8f4b6b7a03ef27fd3d303218fa5d9055bfddb5..9fcea88590861101ffc45f3c9c0eb6d441b8d9fe 100644 (file)
@@ -11,6 +11,7 @@ def suite():
         'orm.lazy_relations',
         'orm.eager_relations',
         'orm.mapper',
+        'orm.selectable',
         'orm.collection',
         'orm.generative',
         'orm.lazytest1',
diff --git a/test/orm/selectable.py b/test/orm/selectable.py
new file mode 100644 (file)
index 0000000..920cd9d
--- /dev/null
@@ -0,0 +1,49 @@
+"""all tests involving generic mapping to Select statements"""
+
+import testbase
+from sqlalchemy import *
+from sqlalchemy.orm import *
+from testlib import *
+from fixtures import *
+from query import QueryTest
+
+class SelectableNoFromsTest(ORMTest):
+    def define_tables(self, metadata):
+        global common_table
+        common_table = Table('common', metadata,
+            Column('id', Integer, primary_key=True),
+            Column('data', Integer),
+            Column('extra', String(45)),
+        )
+    
+    def test_no_tables(self):
+        class Subset(object):
+            pass
+        selectable = select(["x", "y", "z"]).alias('foo')
+        try:
+            mapper(Subset, selectable)
+            compile_mappers()
+            assert False
+        except exceptions.InvalidRequestError, e:
+            assert str(e) == "Could not find any Table objects in mapped table 'SELECT x, y, z'", str(e)
+            
+    def test_basic(self):
+        class Subset(Base):
+            pass
+
+        subset_select = select([common_table.c.id, common_table.c.data]).alias('subset')
+        subset_mapper = mapper(Subset, subset_select)
+
+        sess = create_session(bind=testbase.db)
+        l = Subset()
+        l.data = 1
+        sess.save(l)
+        sess.flush()
+        sess.clear()
+
+        assert [Subset(data=1)] == sess.query(Subset).all()
+
+    # TODO: more tests mapping to selects
+    
+if __name__ == '__main__':
+    testbase.main()
\ No newline at end of file
index 357a66fcdfc16e478b84f59246eb0ff1adcb947e..80a18d49798d961fd9c791e48615b9da3b045c8f 100644 (file)
@@ -166,10 +166,9 @@ class ClauseTest(selecttests.SQLTest):
         assert str(clause2) == str(t1.join(t2, t1.c.col2==t2.c.col3))
     
     def test_select(self):
-        s = t1.select()
-        s2 = select([s])
+        s2 = select([t1])
         s2_assert = str(s2)
-        s3_assert = str(select([t1.select()], t1.c.col2==7))
+        s3_assert = str(select([t1], t1.c.col2==7))
         class Vis(ClauseVisitor):
             def visit_select(self, select):
                 select.append_whereclause(t1.c.col2==7)
@@ -183,7 +182,7 @@ class ClauseTest(selecttests.SQLTest):
 
         print "------------------"
         
-        s4_assert = str(select([t1.select()], and_(t1.c.col2==7, t1.c.col3==9)))
+        s4_assert = str(select([t1], and_(t1.c.col2==7, t1.c.col3==9)))
         class Vis(ClauseVisitor):
             def visit_select(self, select):
                 select.append_whereclause(t1.c.col3==9)
@@ -194,7 +193,7 @@ class ClauseTest(selecttests.SQLTest):
         assert str(s3) == s3_assert
         
         print "------------------"
-        s5_assert = str(select([t1.select()], and_(t1.c.col2==7, t1.c.col1==9)))
+        s5_assert = str(select([t1], and_(t1.c.col2==7, t1.c.col1==9)))
         class Vis(ClauseVisitor):
             def visit_binary(self, binary):
                 if binary.left is t1.c.col3: