]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- Query() can be passed a "composite" attribute
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 18 Dec 2008 16:50:49 +0000 (16:50 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 18 Dec 2008 16:50:49 +0000 (16:50 +0000)
as a column expression and it will be expanded.
Somewhat related to [ticket:1253].
- Query() is a little more robust when passed
various column expressions such as strings,
clauselists, text() constructs (which may mean
it just raises an error more nicely).
- select() can accept a ClauseList as a column
in the same way as a Table or other selectable
and the interior expressions will be used as
column elements. [ticket:1253]
- removed erroneous FooTest from test/orm/query

-This line, and those below, will be ignored--

M    test/orm/query.py
M    test/orm/mapper.py
M    test/sql/select.py
M    lib/sqlalchemy/orm/query.py
M    lib/sqlalchemy/sql/expression.py
M    CHANGES

CHANGES
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/sql/expression.py
test/orm/mapper.py
test/orm/query.py
test/sql/select.py

diff --git a/CHANGES b/CHANGES
index 03c7debccf465844ef85966cf0b1829249d18bb6..9a16edfbc64905438bc54018b58e4738a7f13a98 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -50,6 +50,15 @@ CHANGES
       that the given argument is a FromClause,
       or Text/Select/Union, respectively.
 
+    - Query() can be passed a "composite" attribute
+      as a column expression and it will be expanded.
+      Somewhat related to [ticket:1253].
+      
+    - Query() is a little more robust when passed
+      various column expressions such as strings,
+      clauselists, text() constructs (which may mean
+      it just raises an error more nicely).
+      
     - first() works as expected with Query.from_statement().
     
     - Fixed bug introduced in 0.5rc4 involving eager 
@@ -146,6 +155,11 @@ CHANGES
       also would be a little misleading compared to
       values().
 
+    - select() can accept a ClauseList as a column
+      in the same way as a Table or other selectable
+      and the interior expressions will be used as
+      column elements. [ticket:1253]
+      
     - the "passive" flag on session.is_modified()
       is correctly propagated to the attribute manager.
 
index 9b491896d689892dec5b14849eab5d081dbb91b1..da33eac41e5103c85c6877561e6bb9353c627fd2 100644 (file)
@@ -1633,7 +1633,7 @@ class _QueryEntity(object):
     def __new__(cls, *args, **kwargs):
         if cls is _QueryEntity:
             entity = args[1]
-            if _is_mapped_class(entity):
+            if not isinstance(entity, basestring) and _is_mapped_class(entity):
                 cls = _MapperEntity
             else:
                 cls = _ColumnEntity
@@ -1785,27 +1785,32 @@ class _ColumnEntity(_QueryEntity):
     """Column/expression based entity."""
 
     def __init__(self, query, column):
-        if isinstance(column, expression.FromClause) and not isinstance(column, expression.ColumnElement):
-            for c in column.c:
-                _ColumnEntity(query, c)
-            return
-
-        query._entities.append(self)
-
         if isinstance(column, basestring):
             column = sql.literal_column(column)
             self._result_label = column.name
         elif isinstance(column, (attributes.QueryableAttribute, mapper.Mapper._CompileOnAttr)):
             self._result_label = column.impl.key
             column = column.__clause_element__()
-        elif not isinstance(column, sql.ColumnElement):
-            raise sa_exc.InvalidRequestError("Invalid column expression '%r'" % column)
         else:
             self._result_label = getattr(column, 'key', None)
+        
+        if not isinstance(column, expression.ColumnElement) and hasattr(column, '_select_iterable'):
+            for c in column._select_iterable:
+                if c is column:
+                    break
+                _ColumnEntity(query, c)
+            
+            if c is not column:
+                return
+
+        if not isinstance(column, sql.ColumnElement):
+            raise sa_exc.InvalidRequestError("Invalid column expression '%r'" % column)
 
         if not hasattr(column, '_label'):
             column = column.label(None)
 
+        query._entities.append(self)
+
         self.column = column
         self.froms = set()
 
index aa0b7228c4e94ba64c3f2ecb30e5cf3e772e10e7..0f7f62e74cc6506cf10a7ca9c8cfe1b1e2f79fcd 100644 (file)
@@ -2099,6 +2099,10 @@ class ClauseList(ClauseElement):
     def __len__(self):
         return len(self.clauses)
 
+    @property
+    def _select_iterable(self):
+        return iter(self)
+
     def append(self, clause):
         # TODO: not sure if i like the 'group_contents' flag.  need to
         # define the difference between a ClauseList of ClauseLists,
@@ -2148,6 +2152,10 @@ class BooleanClauseList(ClauseList, ColumnElement):
         super(BooleanClauseList, self).__init__(*clauses, **kwargs)
         self.type = sqltypes.to_instance(kwargs.get('type_', sqltypes.Boolean))
 
+    @property
+    def _select_iterable(self):
+        return (self, )
+
 
 class _CalculatedClause(ColumnElement):
     """Describe a calculated SQL expression that has a type, like ``CASE``.
index dfa2989d7f9045a31ce142cd76078920a08902bc..6fa532043d9a922255f49d0aae0a56888f77663e 100644 (file)
@@ -1676,6 +1676,9 @@ class CompositeTypesTest(_base.MappedTest):
 
         eq_(sess.query(Edge).filter(Edge.start==None).all(), [])
 
+        # query by columns
+        eq_(sess.query(Edge.start, Edge.end).all(), [(3, 4, 5, 6), (14, 5, 19, 5)])
+
     @testing.resolve_artifact_names
     def test_pk(self):
         """Using a composite type as a primary key"""
index 0a70ddaaac7dfdb3c89ac11d1d82e4c3f9f1c920..c617d860a9f5871636f66726f0d0cfe0f1ecb111 100644 (file)
@@ -2,7 +2,7 @@ import testenv; testenv.configure_for_tests()
 import operator
 from sqlalchemy import *
 from sqlalchemy import exc as sa_exc, util
-from sqlalchemy.sql import compiler
+from sqlalchemy.sql import compiler, table, column
 from sqlalchemy.engine import default
 from sqlalchemy.orm import *
 from sqlalchemy.orm import attributes
@@ -574,29 +574,6 @@ class SliceTest(QueryTest):
         ])
 
 
-class FooTest(FixtureTest):
-    keep_data = True
-        
-    def test_filter_by(self):
-        clear_mappers()
-        sess = create_session(bind=testing.db)
-        from sqlalchemy.ext.declarative import declarative_base
-        Base = declarative_base(bind=testing.db)
-        class User(Base, _base.ComparableEntity):
-            __table__ = users
-        
-        class Address(Base, _base.ComparableEntity):
-            __table__ = addresses
-
-        compile_mappers()
-#        Address.user = relation(User, primaryjoin="User.id==Address.user_id")
-        Address.user = relation(User, primaryjoin=User.id==Address.user_id)
-#        Address.user = relation(User, primaryjoin=users.c.id==addresses.c.user_id)
-        compile_mappers()
-#        Address.user.property.primaryjoin = User.id==Address.user_id
-        user = sess.query(User).get(8)
-        print sess.query(Address).filter_by(user=user).all()
-        assert [Address(id=2), Address(id=3), Address(id=4)] == sess.query(Address).filter_by(user=user).all()
     
 class FilterTest(QueryTest):
     def test_basic(self):
@@ -906,6 +883,11 @@ class TextTest(QueryTest):
     def test_binds(self):
         assert [User(id=8), User(id=9)] == create_session().query(User).filter("id in (:id1, :id2)").params(id1=8, id2=9).all()
 
+    def test_as_column(self):
+        s = create_session()
+        self.assertRaises(sa_exc.InvalidRequestError, s.query, User.id, text("users.name"))
+
+        eq_(s.query(User.id, "name").order_by(User.id).all(), [(7, u'jack'), (8, u'ed'), (9, u'fred'), (10, u'chuck')])
 
 class ParentTest(QueryTest):
     def test_o2m(self):
index 352eed42cb50b61d83574658775a22cc52939d9e..12c8524cf8f127508ddde650c3dbd9488b9155f4 100644 (file)
@@ -3,6 +3,7 @@ import datetime, re, operator
 from sqlalchemy import *
 from sqlalchemy import exc, sql, util
 from sqlalchemy.sql import table, column, label, compiler
+from sqlalchemy.sql.expression import ClauseList
 from sqlalchemy.engine import default
 from sqlalchemy.databases import sqlite, postgres, mysql, oracle, firebird, mssql
 from testlib import *
@@ -125,6 +126,11 @@ sq2.sq_myothertable_otherid, sq2.sq_myothertable_othername FROM \
 sq.mytable_description AS sq_mytable_description, sq.myothertable_otherid AS sq_myothertable_otherid, \
 sq.myothertable_othername AS sq_myothertable_othername FROM (" + sqstring + ") AS sq) AS sq2")
 
+    def test_select_from_clauselist(self):
+        self.assert_compile(
+            select([ClauseList(column('a'), column('b'))]).select_from('sometable'), 
+            'SELECT a, b FROM sometable'
+        )
     def test_nested_uselabels(self):
         """test nested anonymous label generation.  this
         essentially tests the ANONYMOUS_LABEL regex.