From: Mike Bayer Date: Thu, 18 Dec 2008 16:50:49 +0000 (+0000) Subject: - Query() can be passed a "composite" attribute X-Git-Tag: rel_0_5_0~94 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=b3337893365a720646e073806b9a379ad839a970;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - Query() can be passed a "composite" attribute as a column expression and it will be expanded. Somewhat related to [ticket:1253]. - Query() is a little more robust when passed various column expressions such as strings, clauselists, text() constructs (which may mean it just raises an error more nicely). - select() can accept a ClauseList as a column in the same way as a Table or other selectable and the interior expressions will be used as column elements. [ticket:1253] - removed erroneous FooTest from test/orm/query -This line, and those below, will be ignored-- M test/orm/query.py M test/orm/mapper.py M test/sql/select.py M lib/sqlalchemy/orm/query.py M lib/sqlalchemy/sql/expression.py M CHANGES --- diff --git a/CHANGES b/CHANGES index 03c7debccf..9a16edfbc6 100644 --- a/CHANGES +++ b/CHANGES @@ -50,6 +50,15 @@ CHANGES that the given argument is a FromClause, or Text/Select/Union, respectively. + - Query() can be passed a "composite" attribute + as a column expression and it will be expanded. + Somewhat related to [ticket:1253]. + + - Query() is a little more robust when passed + various column expressions such as strings, + clauselists, text() constructs (which may mean + it just raises an error more nicely). + - first() works as expected with Query.from_statement(). - Fixed bug introduced in 0.5rc4 involving eager @@ -146,6 +155,11 @@ CHANGES also would be a little misleading compared to values(). + - select() can accept a ClauseList as a column + in the same way as a Table or other selectable + and the interior expressions will be used as + column elements. [ticket:1253] + - the "passive" flag on session.is_modified() is correctly propagated to the attribute manager. diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 9b491896d6..da33eac41e 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -1633,7 +1633,7 @@ class _QueryEntity(object): def __new__(cls, *args, **kwargs): if cls is _QueryEntity: entity = args[1] - if _is_mapped_class(entity): + if not isinstance(entity, basestring) and _is_mapped_class(entity): cls = _MapperEntity else: cls = _ColumnEntity @@ -1785,27 +1785,32 @@ class _ColumnEntity(_QueryEntity): """Column/expression based entity.""" def __init__(self, query, column): - if isinstance(column, expression.FromClause) and not isinstance(column, expression.ColumnElement): - for c in column.c: - _ColumnEntity(query, c) - return - - query._entities.append(self) - if isinstance(column, basestring): column = sql.literal_column(column) self._result_label = column.name elif isinstance(column, (attributes.QueryableAttribute, mapper.Mapper._CompileOnAttr)): self._result_label = column.impl.key column = column.__clause_element__() - elif not isinstance(column, sql.ColumnElement): - raise sa_exc.InvalidRequestError("Invalid column expression '%r'" % column) else: self._result_label = getattr(column, 'key', None) + + if not isinstance(column, expression.ColumnElement) and hasattr(column, '_select_iterable'): + for c in column._select_iterable: + if c is column: + break + _ColumnEntity(query, c) + + if c is not column: + return + + if not isinstance(column, sql.ColumnElement): + raise sa_exc.InvalidRequestError("Invalid column expression '%r'" % column) if not hasattr(column, '_label'): column = column.label(None) + query._entities.append(self) + self.column = column self.froms = set() diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index aa0b7228c4..0f7f62e74c 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -2099,6 +2099,10 @@ class ClauseList(ClauseElement): def __len__(self): return len(self.clauses) + @property + def _select_iterable(self): + return iter(self) + def append(self, clause): # TODO: not sure if i like the 'group_contents' flag. need to # define the difference between a ClauseList of ClauseLists, @@ -2148,6 +2152,10 @@ class BooleanClauseList(ClauseList, ColumnElement): super(BooleanClauseList, self).__init__(*clauses, **kwargs) self.type = sqltypes.to_instance(kwargs.get('type_', sqltypes.Boolean)) + @property + def _select_iterable(self): + return (self, ) + class _CalculatedClause(ColumnElement): """Describe a calculated SQL expression that has a type, like ``CASE``. diff --git a/test/orm/mapper.py b/test/orm/mapper.py index dfa2989d7f..6fa532043d 100644 --- a/test/orm/mapper.py +++ b/test/orm/mapper.py @@ -1676,6 +1676,9 @@ class CompositeTypesTest(_base.MappedTest): eq_(sess.query(Edge).filter(Edge.start==None).all(), []) + # query by columns + eq_(sess.query(Edge.start, Edge.end).all(), [(3, 4, 5, 6), (14, 5, 19, 5)]) + @testing.resolve_artifact_names def test_pk(self): """Using a composite type as a primary key""" diff --git a/test/orm/query.py b/test/orm/query.py index 0a70ddaaac..c617d860a9 100644 --- a/test/orm/query.py +++ b/test/orm/query.py @@ -2,7 +2,7 @@ import testenv; testenv.configure_for_tests() import operator from sqlalchemy import * from sqlalchemy import exc as sa_exc, util -from sqlalchemy.sql import compiler +from sqlalchemy.sql import compiler, table, column from sqlalchemy.engine import default from sqlalchemy.orm import * from sqlalchemy.orm import attributes @@ -574,29 +574,6 @@ class SliceTest(QueryTest): ]) -class FooTest(FixtureTest): - keep_data = True - - def test_filter_by(self): - clear_mappers() - sess = create_session(bind=testing.db) - from sqlalchemy.ext.declarative import declarative_base - Base = declarative_base(bind=testing.db) - class User(Base, _base.ComparableEntity): - __table__ = users - - class Address(Base, _base.ComparableEntity): - __table__ = addresses - - compile_mappers() -# Address.user = relation(User, primaryjoin="User.id==Address.user_id") - Address.user = relation(User, primaryjoin=User.id==Address.user_id) -# Address.user = relation(User, primaryjoin=users.c.id==addresses.c.user_id) - compile_mappers() -# Address.user.property.primaryjoin = User.id==Address.user_id - user = sess.query(User).get(8) - print sess.query(Address).filter_by(user=user).all() - assert [Address(id=2), Address(id=3), Address(id=4)] == sess.query(Address).filter_by(user=user).all() class FilterTest(QueryTest): def test_basic(self): @@ -906,6 +883,11 @@ class TextTest(QueryTest): def test_binds(self): assert [User(id=8), User(id=9)] == create_session().query(User).filter("id in (:id1, :id2)").params(id1=8, id2=9).all() + def test_as_column(self): + s = create_session() + self.assertRaises(sa_exc.InvalidRequestError, s.query, User.id, text("users.name")) + + eq_(s.query(User.id, "name").order_by(User.id).all(), [(7, u'jack'), (8, u'ed'), (9, u'fred'), (10, u'chuck')]) class ParentTest(QueryTest): def test_o2m(self): diff --git a/test/sql/select.py b/test/sql/select.py index 352eed42cb..12c8524cf8 100644 --- a/test/sql/select.py +++ b/test/sql/select.py @@ -3,6 +3,7 @@ import datetime, re, operator from sqlalchemy import * from sqlalchemy import exc, sql, util from sqlalchemy.sql import table, column, label, compiler +from sqlalchemy.sql.expression import ClauseList from sqlalchemy.engine import default from sqlalchemy.databases import sqlite, postgres, mysql, oracle, firebird, mssql from testlib import * @@ -125,6 +126,11 @@ sq2.sq_myothertable_otherid, sq2.sq_myothertable_othername FROM \ sq.mytable_description AS sq_mytable_description, sq.myothertable_otherid AS sq_myothertable_otherid, \ sq.myothertable_othername AS sq_myothertable_othername FROM (" + sqstring + ") AS sq) AS sq2") + def test_select_from_clauselist(self): + self.assert_compile( + select([ClauseList(column('a'), column('b'))]).select_from('sometable'), + 'SELECT a, b FROM sometable' + ) def test_nested_uselabels(self): """test nested anonymous label generation. this essentially tests the ANONYMOUS_LABEL regex.