]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
corrected literals_as_binds to recognize sql.Operators objects for [ticket:675]
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 23 Jul 2007 19:53:16 +0000 (19:53 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 23 Jul 2007 19:53:16 +0000 (19:53 +0000)
lib/sqlalchemy/sql.py
test/orm/query.py

index 3c23d67398370000cfc0261eccde9896e11997d1..04de9d42da6c56f406a6840e5976f739ba7c2abb 100644 (file)
@@ -371,6 +371,7 @@ def between(ctest, cleft, cright):
     provides similar functionality.
     """
 
+    ctest = _literal_as_binds(ctest)
     return _BinaryExpression(ctest, ClauseList(_literal_as_binds(cleft, type_=ctest.type), _literal_as_binds(cright, type_=ctest.type), operator=operator.and_, group=False), ColumnOperators.between_op)
 
 
@@ -752,7 +753,9 @@ def _literal_as_text(element):
         return element
 
 def _literal_as_binds(element, name='literal', type_=None):
-    if _is_literal(element):
+    if isinstance(element, Operators):
+        return element.clause_element()
+    elif _is_literal(element):
         if element is None:
             return null()
         else:
index 9d516f3780b51237a2f5b809f6330c30bac7531f..9f14d9a25a6ac0ffc2c6bd6d771c5b3a57306760 100644 (file)
@@ -213,6 +213,18 @@ class OperatorTest(QueryTest):
     def test_in(self):
          self._test(User.id.in_('a', 'b'), "users.id IN (:users_id, :users_id_1)")
     
+    def test_clauses(self):
+        for (expr, compare) in (
+            (func.max(User.id), "max(users.id)"),
+            (desc(User.id), "users.id DESC"),
+            (between(5, User.id, Address.id), ":literal BETWEEN users.id AND addresses.id"),
+            # this one would require adding compile() to InstrumentedScalarAttribute.  do we want this ?
+            #(User.id, "users.id")
+        ):
+            c = expr.compile(dialect=ansisql.ANSIDialect())
+            assert str(c) == compare, "%s != %s" % (str(c), compare)
+            
+            
 class CompileTest(QueryTest):
     def test_deferred(self):
         session = create_session()
@@ -642,7 +654,7 @@ class InstancesTest(QueryTest):
         
         for aliased in (False, True):
             q = sess.query(User)
-            q = q.group_by([c for c in users.c]).order_by(User.id).outerjoin('addresses', aliased=aliased).add_column(func.count(addresses.c.id).label('count'))
+            q = q.group_by([c for c in users.c]).order_by(User.id).outerjoin('addresses', aliased=aliased).add_column(func.count(Address.id).label('count'))
             l = q.all()
             assert l == expected