From cc9292be84f87ea56be7732ef0062219cf37335a Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Tue, 17 Jul 2007 04:54:30 +0000 Subject: [PATCH] - got in_() working, enhanced sql.py treatment of Comparator so comparators can be used in any SQL expression (i.e. order bys, desc(), etc.) - adding various tests for new clause generation --- lib/sqlalchemy/orm/attributes.py | 4 ++-- lib/sqlalchemy/orm/properties.py | 4 ++-- lib/sqlalchemy/sql.py | 22 ++++++++++++---------- test/orm/eager_relations.py | 7 ++++--- test/orm/inheritance/polymorph.py | 9 ++++----- test/orm/query.py | 7 +++++-- 6 files changed, 29 insertions(+), 24 deletions(-) diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py index ad9675f029..a3db154c65 100644 --- a/lib/sqlalchemy/orm/attributes.py +++ b/lib/sqlalchemy/orm/attributes.py @@ -81,8 +81,8 @@ class InstrumentedAttribute(sql.Comparator): return self return self.get(obj) - def compare_self(self): - return self.comparator.compare_self() + def clause_element(self): + return self.comparator.clause_element() def operate(self, op, other): return op(self.comparator, other) diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index 8a57b4a83e..9ab3cce229 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -60,7 +60,7 @@ class ColumnProperty(StrategizedProperty): return value class ColumnComparator(PropComparator): - def compare_self(self): + def clause_element(self): return self.prop.columns[0] def operate(self, op, other): @@ -69,7 +69,7 @@ class ColumnProperty(StrategizedProperty): def reverse_operate(self, op, other): col = self.prop.columns[0] return op(col._bind_param(other), col) - + ColumnProperty.logger = logging.class_logger(ColumnProperty) diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py index b6a843685c..e044729e08 100644 --- a/lib/sqlalchemy/sql.py +++ b/lib/sqlalchemy/sql.py @@ -781,7 +781,9 @@ def _is_literal(element): return not isinstance(element, ClauseElement) def _literal_as_text(element): - if _is_literal(element): + if isinstance(element, Comparator): + return element.clause_element() + elif _is_literal(element): return _TextClause(unicode(element)) else: return element @@ -1144,7 +1146,7 @@ class Comparator(object): between_op = staticmethod(between_op) def in_op(a, b): - return a.in_(b) + return a.in_(*b) in_op = staticmethod(in_op) def startswith_op(a, b): @@ -1155,7 +1157,7 @@ class Comparator(object): return a.endswith(b) endswith_op = staticmethod(endswith_op) - def compare_self(self): + def clause_element(self): raise NotImplementedError() def operate(self, op, other): @@ -1233,19 +1235,19 @@ class _CompareMixin(Comparator): def __compare(self, operator, obj, negate=None): if obj is None or isinstance(obj, _Null): if operator == '=': - return _BinaryExpression(self.compare_self(), null(), 'IS', negate='IS NOT') + return _BinaryExpression(self.clause_element(), null(), 'IS', negate='IS NOT') elif operator == '!=': - return _BinaryExpression(self.compare_self(), null(), 'IS NOT', negate='IS') + return _BinaryExpression(self.clause_element(), null(), 'IS NOT', negate='IS') else: raise exceptions.ArgumentError("Only '='/'!=' operators can be used with NULL") else: obj = self._check_literal(obj) - return _BinaryExpression(self.compare_self(), obj, operator, type=sqltypes.Boolean, negate=negate) + return _BinaryExpression(self.clause_element(), obj, operator, type=sqltypes.Boolean, negate=negate) def __operate(self, operator, obj): obj = self._check_literal(obj) - return _BinaryExpression(self.compare_self(), obj, operator, type=self._compare_type(obj)) + return _BinaryExpression(self.clause_element(), obj, operator, type=self._compare_type(obj)) operators = { operator.add : (__operate, '+'), @@ -1341,13 +1343,13 @@ class _CompareMixin(Comparator): def _check_literal(self, other): if isinstance(other, Comparator): - return other.compare_self() + return other.clause_element() elif _is_literal(other): return self._bind_param(other) else: return other - def compare_self(self): + def clause_element(self): """Allow ``_CompareMixins`` to return the appropriate object to be used in expressions.""" return self @@ -2456,7 +2458,7 @@ class _Label(ColumnElement): _label = property(lambda s: s.name) orig_set = property(lambda s:s.obj.orig_set) - def compare_self(self): + def clause_element(self): return self.obj def _copy_internals(self): diff --git a/test/orm/eager_relations.py b/test/orm/eager_relations.py index 37b5ecdf7e..90ae3ba53e 100644 --- a/test/orm/eager_relations.py +++ b/test/orm/eager_relations.py @@ -20,7 +20,7 @@ class EagerTest(QueryTest): sess = create_session() q = sess.query(User) - assert [User(id=7, addresses=[Address(id=1, email_address='jack@bean.com')])] == q.filter(users.c.id == 7).all() + assert [User(id=7, addresses=[Address(id=1, email_address='jack@bean.com')])] == q.filter(User.id==7).all() assert fixtures.user_address_result == q.all() def test_no_orphan(self): @@ -375,6 +375,7 @@ class EagerTest(QueryTest): 'user':relation(User, lazy=False) }) mapper(User, users) + mapper(Item, items) q = create_session().query(Order) assert [ @@ -382,7 +383,7 @@ class EagerTest(QueryTest): Order(id=4, user=User(id=9)) ] == q.all() - q = q.select_from(s.join(order_items).join(items)).filter(~items.c.id.in_(1, 2, 5)) + q = q.select_from(s.join(order_items).join(items)).filter(~Item.id.in_(1, 2, 5)) assert [ Order(id=3, user=User(id=7)), ] == q.all() @@ -394,7 +395,7 @@ class EagerTest(QueryTest): addresses = relation(mapper(Address, addresses), lazy=False) )) q = create_session().query(User) - l = q.filter(addresses.c.email_address == 'ed@lala.com').filter(addresses.c.user_id==users.c.id) + l = q.filter(addresses.c.email_address == 'ed@lala.com').filter(Address.user_id==User.id) assert fixtures.user_address_result[1:2] == l.all() if __name__ == '__main__': diff --git a/test/orm/inheritance/polymorph.py b/test/orm/inheritance/polymorph.py index 0fadfa1950..d7900610ff 100644 --- a/test/orm/inheritance/polymorph.py +++ b/test/orm/inheritance/polymorph.py @@ -295,18 +295,17 @@ def generate_round_trip_test(include_base=False, lazy_relation=True, redefine_co assert sets.Set([(e.get_name(), e.status) for e in c.employees]) == sets.Set([('pointy haired boss', 'AAB'), ('dilbert', 'BBA'), ('wally', 'CGG'), ('jsmith', 'ABA')]) print "\n" - # test selecting from the query, using the base mapped table (people) as the selection criterion. # in the case of the polymorphic Person query, the "people" selectable should be adapted to be "person_join" - dilbert = session.query(Person).selectfirst(people.c.name=='dilbert') - dilbert2 = session.query(Engineer).selectfirst(people.c.name=='dilbert') + dilbert = session.query(Person).filter(getattr(Person, person_attribute_name)=='dilbert').first() + dilbert2 = session.query(Engineer).filter(getattr(Person, person_attribute_name)=='dilbert').first() assert dilbert is dilbert2 # test selecting from the query, joining against an alias of the base "people" table. test that # the "palias" alias does *not* get sucked up into the "person_join" conversion. palias = people.alias("palias") - session.query(Person).selectfirst((palias.c.name=='dilbert') & (palias.c.person_id==people.c.person_id)) - dilbert2 = session.query(Engineer).selectfirst((palias.c.name=='dilbert') & (palias.c.person_id==people.c.person_id)) + session.query(Person).filter((palias.c.name=='dilbert') & (palias.c.person_id==Person.person_id)).first() + dilbert2 = session.query(Engineer).filter((palias.c.name=='dilbert') & (palias.c.person_id==Person.person_id)).first() assert dilbert is dilbert2 session.query(Person).selectfirst((engineers.c.engineer_name=="engineer1") & (engineers.c.person_id==people.c.person_id)) diff --git a/test/orm/query.py b/test/orm/query.py index 75885fb8ce..4688d593fd 100644 --- a/test/orm/query.py +++ b/test/orm/query.py @@ -152,6 +152,9 @@ class OperatorTest(QueryTest): "\n'" + compiled + "'\n does not match\n'" + fwd_sql + "'\n or\n'" + rev_sql + "'") + def test_in(self): + self._test(User.id.in_('a', 'b'), "users.id IN (:users_id, :users_id_1)") + class CompileTest(QueryTest): def test_deferred(self): session = create_session() @@ -469,11 +472,11 @@ class InstancesTest(QueryTest): ] q = sess.query(User) - q = q.group_by([c for c in users.c]).order_by(User.c.id).outerjoin('addresses').add_column(func.count(addresses.c.id).label('count')) + q = q.group_by([c for c in users.c]).order_by(User.id).outerjoin('addresses').add_column(func.count(addresses.c.id).label('count')) l = q.all() assert l == expected - s = select([users, func.count(addresses.c.id).label('count')]).select_from(users.outerjoin(addresses)).group_by(*[c for c in users.c]).order_by(users.c.id) + s = select([users, func.count(addresses.c.id).label('count')]).select_from(users.outerjoin(addresses)).group_by(*[c for c in users.c]).order_by(User.id) q = sess.query(User) l = q.add_column("count").from_statement(s).all() assert l == expected -- 2.47.3