From cdcd74cb390507fc4e04bd61f77a2a2d2baa9614 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Sun, 27 Nov 2005 05:31:22 +0000 Subject: [PATCH] some fixes to IN clauses, literal text clauses displaying text/numeric properly including longs --- lib/sqlalchemy/ansisql.py | 10 +++++----- lib/sqlalchemy/sql.py | 26 ++++++++++++++++++-------- 2 files changed, 23 insertions(+), 13 deletions(-) diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py index 6a618e6cd4..55550bfa88 100644 --- a/lib/sqlalchemy/ansisql.py +++ b/lib/sqlalchemy/ansisql.py @@ -151,10 +151,7 @@ class ANSICompiler(sql.Compiled): if compound.operator is None: sep = " " else: - if compound.spaces: - sep = compound.operator - else: - sep = " " + compound.operator + " " + sep = " " + compound.operator + " " if compound.parens: self.strings[compound] = "(" + string.join([self.get_str(c) for c in compound.clauses], sep) + ")" @@ -162,7 +159,10 @@ class ANSICompiler(sql.Compiled): self.strings[compound] = string.join([self.get_str(c) for c in compound.clauses], sep) def visit_clauselist(self, list): - self.strings[list] = string.join([self.get_str(c) for c in list.clauses], ', ') + if list.parens: + self.strings[list] = "(" + string.join([self.get_str(c) for c in list.clauses], ', ') + ")" + else: + self.strings[list] = string.join([self.get_str(c) for c in list.clauses], ', ') def visit_binary(self, binary): result = self.get_str(binary.left) diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py index 54b604930a..5b68ff1bdd 100644 --- a/lib/sqlalchemy/sql.py +++ b/lib/sqlalchemy/sql.py @@ -21,7 +21,7 @@ import sqlalchemy.schema as schema import sqlalchemy.util as util import sqlalchemy.types as types -import string +import string, re __ALL__ = ['textclause', 'select', 'join', 'and_', 'or_', 'not_', 'union', 'unionall', 'desc', 'asc', 'outerjoin', 'alias', 'subquery', 'bindparam', 'sequence'] @@ -328,8 +328,11 @@ class CompareMixin(object): elif len(other) == 1 and not isinstance(other[0], Selectable): return self.__eq__(other[0]) elif _is_literal(other[0]): - return self._compare('IN', CompoundClause(',', spaces=False, parens=True, *other)) + return self._compare('IN', ClauseList(parens=True, *[TextClause(o, isliteral=True) for o in other])) else: + # assume *other is a list of selects. + # so put them in a UNION. if theres only one, you just get one SELECT + # statement out of it. return self._compare('IN', union(*other)) def startswith(self, other): @@ -421,12 +424,19 @@ class BindParamClause(ClauseElement): return self.type.convert_bind_param(value) class TextClause(ClauseElement): - """represents any plain text WHERE clause or full SQL statement""" + """represents literal text, including SQL fragments as well + as literal (non bind-param) values.""" - def __init__(self, text = "", engine=None): + def __init__(self, text = "", engine=None, isliteral=False): self.text = text self.parens = False self.engine = engine + if isliteral: + if isinstance(text, int) or isinstance(text, long): + self.text = str(text) + else: + text = re.sub(r"'", r"''", text) + self.text = "'" + text + "'" def accept_visitor(self, visitor): visitor.visit_textclause(self) def hash_key(self): @@ -447,8 +457,7 @@ class CompoundClause(ClauseElement): def __init__(self, operator, *clauses, **kwargs): self.operator = operator self.clauses = [] - self.parens = kwargs.pop('parens', False) - self.spaces = kwargs.pop('spaces', False) + self.parens = False for c in clauses: if c is None: continue self.append(c) @@ -459,7 +468,7 @@ class CompoundClause(ClauseElement): def append(self, clause): if _is_literal(clause): - clause = TextClause(repr(clause)) + clause = TextClause(str(clause)) elif isinstance(clause, CompoundClause): clause.parens = True self.clauses.append(clause) @@ -479,8 +488,9 @@ class CompoundClause(ClauseElement): return string.join([c.hash_key() for c in self.clauses], self.operator) class ClauseList(ClauseElement): - def __init__(self, *clauses): + def __init__(self, *clauses, **kwargs): self.clauses = clauses + self.parens = kwargs.get('parens', False) def accept_visitor(self, visitor): for c in self.clauses: -- 2.47.2