]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- Query.select_from(), from_statement() ensure
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 16 Nov 2008 19:33:26 +0000 (19:33 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 16 Nov 2008 19:33:26 +0000 (19:33 +0000)
that the given argument is a FromClause,
or Text/Select/Union, respectively.

- Query.add_column() can accept FromClause objects
in the same manner as session.query() can.

CHANGES
lib/sqlalchemy/orm/query.py
test/orm/query.py

diff --git a/CHANGES b/CHANGES
index fa95094db8748e47dbe9b61c16cd4e3f23dc5b8f..6d26b26884d8ddb49592f3bc2b7d947f5f32fcf8 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -6,7 +6,15 @@ CHANGES
 =======
 0.5.0rc5
 ========
-- bugfixes
+- bugfixes, behavioral changes
+- orm
+    - Query.select_from(), from_statement() ensure
+      that the given argument is a FromClause,
+      or Text/Select/Union, respectively.
+
+    - Query.add_column() can accept FromClause objects
+      in the same manner as session.query() can.
+      
 - postgres
     - Calling alias.execute() in conjunction with
       server_side_cursors won't raise AttributeError.
index 82a698fd07e1e7be56332e8b706ce8c0c9c721e7..d8887631282f5da5e31857b4d0f4c1f64b3be184 100644 (file)
@@ -562,8 +562,11 @@ class Query(object):
         """Add a SQL ColumnElement to the list of result columns to be returned."""
 
         self._entities = list(self._entities)
-        c = _ColumnEntity(self, column)
-        self.__setup_aliasizers([c])
+        l = len(self._entities)
+        _ColumnEntity(self, column)
+        # _ColumnEntity may add many entities if the
+        # given arg is a FROM clause
+        self.__setup_aliasizers(self._entities[l:])
 
     def options(self, *args):
         """Return a new Query object, applying the given list of
@@ -930,6 +933,8 @@ class Query(object):
         if isinstance(from_obj, (tuple, list)):
             util.warn_deprecated("select_from() now accepts a single Selectable as its argument, which replaces any existing FROM criterion.")
             from_obj = from_obj[-1]
+        if not isinstance(from_obj, expression.FromClause):
+            raise sa_exc.ArgumentError("select_from() accepts FromClause objects only.")
         self.__set_select_from(from_obj)
 
     def __getitem__(self, item):
@@ -1013,6 +1018,10 @@ class Query(object):
         """
         if isinstance(statement, basestring):
             statement = sql.text(statement)
+
+        if not isinstance(statement, (expression._TextClause, expression._SelectBaseMixin)):
+            raise sa_exc.ArgumentError("from_statement accepts text(), select(), and union() objects only.")
+        
         self._statement = statement
 
     def first(self):
index 4121ac0e1913fe84297239f1d1c7777e63485f17..34372b83ab34d35d1f0d943e576ffd00fa479374 100644 (file)
@@ -236,6 +236,23 @@ class InvalidGenerationsTest(QueryTest):
         # this is fine, however
         q.from_self()
     
+    def test_invalid_select_from(self):
+        s = create_session()
+        q = s.query(User)
+        self.assertRaises(sa_exc.ArgumentError, q.select_from, User.id==5)
+        self.assertRaises(sa_exc.ArgumentError, q.select_from, User.id)
+
+    def test_invalid_from_statement(self):
+        s = create_session()
+        q = s.query(User)
+        self.assertRaises(sa_exc.ArgumentError, q.from_statement, User.id==5)
+        self.assertRaises(sa_exc.ArgumentError, q.from_statement, users.join(addresses))
+    
+    def test_invalid_column(self):
+        s = create_session()
+        q = s.query(User)
+        self.assertRaises(sa_exc.InvalidRequestError, q.add_column, object())
+        
     def test_mapper_zero(self):
         s = create_session()
         
@@ -1783,6 +1800,16 @@ class MixedEntitiesTest(QueryTest):
 
         self.assertRaises(sa_exc.InvalidRequestError, sess.query(User).add_column, object())
     
+    def test_add_multi_columns(self):
+        """test that add_column accepts a FROM clause."""
+        
+        sess = create_session()
+        
+        eq_(
+            sess.query(User.id).add_column(users).all(),
+            [(7, 7, u'jack'), (8, 8, u'ed'), (9, 9, u'fred'), (10, 10, u'chuck')]
+        )
+        
     def test_multi_columns_2(self):
         """test aliased/nonalised joins with the usage of add_column()"""
         sess = create_session()