]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- got in_() working, enhanced sql.py treatment of Comparator so comparators can be...
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 17 Jul 2007 04:54:30 +0000 (04:54 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 17 Jul 2007 04:54:30 +0000 (04:54 +0000)
- adding various tests for new clause generation

lib/sqlalchemy/orm/attributes.py
lib/sqlalchemy/orm/properties.py
lib/sqlalchemy/sql.py
test/orm/eager_relations.py
test/orm/inheritance/polymorph.py
test/orm/query.py

index ad9675f029be44d493fda6c4423746a2b9ab5cc4..a3db154c65fa8b5bede618ff740511363e9614c9 100644 (file)
@@ -81,8 +81,8 @@ class InstrumentedAttribute(sql.Comparator):
             return self
         return self.get(obj)
 
-    def compare_self(self):
-        return self.comparator.compare_self()
+    def clause_element(self):
+        return self.comparator.clause_element()
         
     def operate(self, op, other):
         return op(self.comparator, other)
index 8a57b4a83ead56fa8ffc2e1acd3e39e9135695fe..9ab3cce2296196c79970f6586721e19507b4358e 100644 (file)
@@ -60,7 +60,7 @@ class ColumnProperty(StrategizedProperty):
         return value
 
     class ColumnComparator(PropComparator):
-        def compare_self(self):
+        def clause_element(self):
             return self.prop.columns[0]
             
         def operate(self, op, other):
@@ -69,7 +69,7 @@ class ColumnProperty(StrategizedProperty):
         def reverse_operate(self, op, other):
             col = self.prop.columns[0]
             return op(col._bind_param(other), col)
-            
+
             
 ColumnProperty.logger = logging.class_logger(ColumnProperty)
 
index b6a843685caed6fa8ac46622f17f1f19af845b15..e044729e0810754f48ee6b2bf31eddc41c666fc2 100644 (file)
@@ -781,7 +781,9 @@ def _is_literal(element):
     return not isinstance(element, ClauseElement)
 
 def _literal_as_text(element):
-    if _is_literal(element):
+    if isinstance(element, Comparator):
+        return element.clause_element()
+    elif _is_literal(element):
         return _TextClause(unicode(element))
     else:
         return element
@@ -1144,7 +1146,7 @@ class Comparator(object):
     between_op = staticmethod(between_op)
     
     def in_op(a, b):
-        return a.in_(b)
+        return a.in_(*b)
     in_op = staticmethod(in_op)
     
     def startswith_op(a, b):
@@ -1155,7 +1157,7 @@ class Comparator(object):
         return a.endswith(b)
     endswith_op = staticmethod(endswith_op)
     
-    def compare_self(self):
+    def clause_element(self):
         raise NotImplementedError()
         
     def operate(self, op, other):
@@ -1233,19 +1235,19 @@ class _CompareMixin(Comparator):
     def __compare(self, operator, obj, negate=None):
         if obj is None or isinstance(obj, _Null):
             if operator == '=':
-                return _BinaryExpression(self.compare_self(), null(), 'IS', negate='IS NOT')
+                return _BinaryExpression(self.clause_element(), null(), 'IS', negate='IS NOT')
             elif operator == '!=':
-                return _BinaryExpression(self.compare_self(), null(), 'IS NOT', negate='IS')
+                return _BinaryExpression(self.clause_element(), null(), 'IS NOT', negate='IS')
             else:
                 raise exceptions.ArgumentError("Only '='/'!=' operators can be used with NULL")
         else:
             obj = self._check_literal(obj)
 
-        return _BinaryExpression(self.compare_self(), obj, operator, type=sqltypes.Boolean, negate=negate)
+        return _BinaryExpression(self.clause_element(), obj, operator, type=sqltypes.Boolean, negate=negate)
 
     def __operate(self, operator, obj):
         obj = self._check_literal(obj)
-        return _BinaryExpression(self.compare_self(), obj, operator, type=self._compare_type(obj))
+        return _BinaryExpression(self.clause_element(), obj, operator, type=self._compare_type(obj))
 
     operators = {
         operator.add : (__operate, '+'),
@@ -1341,13 +1343,13 @@ class _CompareMixin(Comparator):
 
     def _check_literal(self, other):
         if isinstance(other, Comparator):
-            return other.compare_self()
+            return other.clause_element()
         elif _is_literal(other):
             return self._bind_param(other)
         else:
             return other
     
-    def compare_self(self):
+    def clause_element(self):
         """Allow ``_CompareMixins`` to return the appropriate object to be used in expressions."""
 
         return self
@@ -2456,7 +2458,7 @@ class _Label(ColumnElement):
     _label = property(lambda s: s.name)
     orig_set = property(lambda s:s.obj.orig_set)
 
-    def compare_self(self):
+    def clause_element(self):
         return self.obj
     
     def _copy_internals(self):
index 37b5ecdf7e5984a370300ef29ca9ef222564a080..90ae3ba53e3295d5828bb493f039c316904d4235 100644 (file)
@@ -20,7 +20,7 @@ class EagerTest(QueryTest):
         sess = create_session()
         q = sess.query(User)
 
-        assert [User(id=7, addresses=[Address(id=1, email_address='jack@bean.com')])] == q.filter(users.c.id == 7).all()
+        assert [User(id=7, addresses=[Address(id=1, email_address='jack@bean.com')])] == q.filter(User.id==7).all()
         assert fixtures.user_address_result == q.all()
 
     def test_no_orphan(self):
@@ -375,6 +375,7 @@ class EagerTest(QueryTest):
             'user':relation(User, lazy=False)
         })
         mapper(User, users)
+        mapper(Item, items)
 
         q = create_session().query(Order)
         assert [
@@ -382,7 +383,7 @@ class EagerTest(QueryTest):
             Order(id=4, user=User(id=9))
         ] == q.all()
         
-        q = q.select_from(s.join(order_items).join(items)).filter(~items.c.id.in_(1, 2, 5))
+        q = q.select_from(s.join(order_items).join(items)).filter(~Item.id.in_(1, 2, 5))
         assert [
             Order(id=3, user=User(id=7)),
         ] == q.all()
@@ -394,7 +395,7 @@ class EagerTest(QueryTest):
             addresses = relation(mapper(Address, addresses), lazy=False)
         ))
         q = create_session().query(User)
-        l = q.filter(addresses.c.email_address == 'ed@lala.com').filter(addresses.c.user_id==users.c.id)
+        l = q.filter(addresses.c.email_address == 'ed@lala.com').filter(Address.user_id==User.id)
         assert fixtures.user_address_result[1:2] == l.all()
 
 if __name__ == '__main__':
index 0fadfa1950dbd8a24b5146be69528d88775f7969..d7900610ffeca39b031bc18a6d10a4579af997a2 100644 (file)
@@ -295,18 +295,17 @@ def generate_round_trip_test(include_base=False, lazy_relation=True, redefine_co
             assert sets.Set([(e.get_name(), e.status) for e in c.employees]) == sets.Set([('pointy haired boss', 'AAB'), ('dilbert', 'BBA'), ('wally', 'CGG'), ('jsmith', 'ABA')])
         print "\n"
 
-    
         # test selecting from the query, using the base mapped table (people) as the selection criterion.
         # in the case of the polymorphic Person query, the "people" selectable should be adapted to be "person_join"
-        dilbert = session.query(Person).selectfirst(people.c.name=='dilbert')
-        dilbert2 = session.query(Engineer).selectfirst(people.c.name=='dilbert')
+        dilbert = session.query(Person).filter(getattr(Person, person_attribute_name)=='dilbert').first()
+        dilbert2 = session.query(Engineer).filter(getattr(Person, person_attribute_name)=='dilbert').first()
         assert dilbert is dilbert2
 
         # test selecting from the query, joining against an alias of the base "people" table.  test that
         # the "palias" alias does *not* get sucked up into the "person_join" conversion.
         palias = people.alias("palias")
-        session.query(Person).selectfirst((palias.c.name=='dilbert') & (palias.c.person_id==people.c.person_id))
-        dilbert2 = session.query(Engineer).selectfirst((palias.c.name=='dilbert') & (palias.c.person_id==people.c.person_id))
+        session.query(Person).filter((palias.c.name=='dilbert') & (palias.c.person_id==Person.person_id)).first()
+        dilbert2 = session.query(Engineer).filter((palias.c.name=='dilbert') & (palias.c.person_id==Person.person_id)).first()
         assert dilbert is dilbert2
 
         session.query(Person).selectfirst((engineers.c.engineer_name=="engineer1") & (engineers.c.person_id==people.c.person_id))
index 75885fb8ce09a8e6314c953beb2d2fa4791884af..4688d593fd0f212050ecb024de5a5e2988245ff5 100644 (file)
@@ -152,6 +152,9 @@ class OperatorTest(QueryTest):
                              "\n'" + compiled + "'\n does not match\n'" +
                              fwd_sql + "'\n or\n'" + rev_sql + "'")
     
+    def test_in(self):
+         self._test(User.id.in_('a', 'b'), "users.id IN (:users_id, :users_id_1)")
+    
 class CompileTest(QueryTest):
     def test_deferred(self):
         session = create_session()
@@ -469,11 +472,11 @@ class InstancesTest(QueryTest):
             ]
             
         q = sess.query(User)
-        q = q.group_by([c for c in users.c]).order_by(User.c.id).outerjoin('addresses').add_column(func.count(addresses.c.id).label('count'))
+        q = q.group_by([c for c in users.c]).order_by(User.id).outerjoin('addresses').add_column(func.count(addresses.c.id).label('count'))
         l = q.all()
         assert l == expected
 
-        s = select([users, func.count(addresses.c.id).label('count')]).select_from(users.outerjoin(addresses)).group_by(*[c for c in users.c]).order_by(users.c.id)
+        s = select([users, func.count(addresses.c.id).label('count')]).select_from(users.outerjoin(addresses)).group_by(*[c for c in users.c]).order_by(User.id)
         q = sess.query(User)
         l = q.add_column("count").from_statement(s).all()
         assert l == expected