]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
some fixes to IN clauses, literal text clauses displaying text/numeric properly including
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 27 Nov 2005 05:31:22 +0000 (05:31 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 27 Nov 2005 05:31:22 +0000 (05:31 +0000)
longs

lib/sqlalchemy/ansisql.py
lib/sqlalchemy/sql.py

index 6a618e6cd4114c79caecde297657c4734ff92c29..55550bfa88650415c7214732a9e5ae413b37f74c 100644 (file)
@@ -151,10 +151,7 @@ class ANSICompiler(sql.Compiled):
         if compound.operator is None:
             sep = " "
         else:
-            if compound.spaces:
-                sep = compound.operator
-            else:
-                sep = " " + compound.operator + " "
+            sep = " " + compound.operator + " "
         
         if compound.parens:
             self.strings[compound] = "(" + string.join([self.get_str(c) for c in compound.clauses], sep) + ")"
@@ -162,7 +159,10 @@ class ANSICompiler(sql.Compiled):
             self.strings[compound] = string.join([self.get_str(c) for c in compound.clauses], sep)
         
     def visit_clauselist(self, list):
-        self.strings[list] = string.join([self.get_str(c) for c in list.clauses], ', ')
+        if list.parens:
+            self.strings[list] = "(" + string.join([self.get_str(c) for c in list.clauses], ', ') + ")"
+        else:
+            self.strings[list] = string.join([self.get_str(c) for c in list.clauses], ', ')
         
     def visit_binary(self, binary):
         result = self.get_str(binary.left)
index 54b604930a9d727a72d343c4666c0453a4658e74..5b68ff1bddd169a03a6d966c5763b7811c2d74c8 100644 (file)
@@ -21,7 +21,7 @@
 import sqlalchemy.schema as schema
 import sqlalchemy.util as util
 import sqlalchemy.types as types
-import string
+import string, re
 
 __ALL__ = ['textclause', 'select', 'join', 'and_', 'or_', 'not_', 'union', 'unionall', 'desc', 'asc', 'outerjoin', 'alias', 'subquery', 'bindparam', 'sequence']
 
@@ -328,8 +328,11 @@ class CompareMixin(object):
         elif len(other) == 1 and not isinstance(other[0], Selectable):
             return self.__eq__(other[0])
         elif _is_literal(other[0]):
-            return self._compare('IN', CompoundClause(',', spaces=False, parens=True, *other))
+            return self._compare('IN', ClauseList(parens=True, *[TextClause(o, isliteral=True) for o in other]))
         else:
+            # assume *other is a list of selects.
+            # so put them in a UNION.  if theres only one, you just get one SELECT 
+            # statement out of it.
             return self._compare('IN', union(*other))
 
     def startswith(self, other):
@@ -421,12 +424,19 @@ class BindParamClause(ClauseElement):
         return self.type.convert_bind_param(value)
             
 class TextClause(ClauseElement):
-    """represents any plain text WHERE clause or full SQL statement"""
+    """represents literal text, including SQL fragments as well
+    as literal (non bind-param) values."""
     
-    def __init__(self, text = "", engine=None):
+    def __init__(self, text = "", engine=None, isliteral=False):
         self.text = text
         self.parens = False
         self.engine = engine
+        if isliteral:
+            if isinstance(text, int) or isinstance(text, long):
+                self.text = str(text)
+            else:
+                text = re.sub(r"'", r"''", text)
+                self.text = "'" + text + "'"
     def accept_visitor(self, visitor): 
         visitor.visit_textclause(self)
     def hash_key(self):
@@ -447,8 +457,7 @@ class CompoundClause(ClauseElement):
     def __init__(self, operator, *clauses, **kwargs):
         self.operator = operator
         self.clauses = []
-        self.parens = kwargs.pop('parens', False)
-        self.spaces = kwargs.pop('spaces', False)
+        self.parens = False
         for c in clauses:
             if c is None: continue
             self.append(c)
@@ -459,7 +468,7 @@ class CompoundClause(ClauseElement):
         
     def append(self, clause):
         if _is_literal(clause):
-            clause = TextClause(repr(clause))
+            clause = TextClause(str(clause))
         elif isinstance(clause, CompoundClause):
             clause.parens = True
         self.clauses.append(clause)
@@ -479,8 +488,9 @@ class CompoundClause(ClauseElement):
         return string.join([c.hash_key() for c in self.clauses], self.operator)
         
 class ClauseList(ClauseElement):
-    def __init__(self, *clauses):
+    def __init__(self, *clauses, **kwargs):
         self.clauses = clauses
+        self.parens = kwargs.get('parens', False)
         
     def accept_visitor(self, visitor):
         for c in self.clauses: