From: Mike Bayer Date: Sat, 5 Nov 2005 05:17:15 +0000 (+0000) Subject: (no commit message) X-Git-Tag: rel_0_1_0~375 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=f0b12f6f47932232b271ca87e067d06f42504eec;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git --- diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py index 1209b2324e..7205a5e9fb 100644 --- a/lib/sqlalchemy/ansisql.py +++ b/lib/sqlalchemy/ansisql.py @@ -46,9 +46,8 @@ class ANSISQLEngine(sqlalchemy.engine.SQLEngine): def dbapi(self): return object() - class ANSICompiler(sql.Compiled): - def __init__(self, engine, statement, bindparams, typemap=None, **kwargs): + def __init__(self, engine, statement, bindparams, typemap=None, paramstyle=None,**kwargs): sql.Compiled.__init__(self, engine, statement, bindparams) self.binds = {} self.froms = {} @@ -57,6 +56,39 @@ class ANSICompiler(sql.Compiled): self.typemap = typemap or {} self.isinsert = False + if paramstyle is None: + paramstyle = self.engine.dbapi().paramstyle + + if paramstyle == 'named': + self.bindtemplate = ':%s' + self.positional=False + elif paramstyle =='pyformat': + self.bindtemplate = "%%(%s)s" + self.positional=False + else: + # for positional, use pyformat until the end + self.bindtemplate = "%%(%s)s" + self.positional=True + self.paramstyle=paramstyle + + def after_compile(self): + if self.positional: + self.positiontup = [] + match = r'%\(([\w_]+)\)s' + params = re.finditer(match, self.strings[self.statement]) + for p in params: + self.positiontup.append(p.group(1)) + if self.paramstyle=='qmark': + self.strings[self.statement] = re.sub(match, '?', self.strings[self.statement]) + elif self.paramstyle=='format': + self.strings[self.statement] = re.sub(match, '%s', self.strings[self.statement]) + elif self.paramstyle=='numeric': + i = 0 + def getnum(x): + i += 1 + return i + self.strings[self.statement] = re.sub(match, getnum(s), self.strings[self.statement]) + def get_from_text(self, obj): return self.froms[obj] @@ -85,7 +117,10 @@ class ANSICompiler(sql.Compiled): for b in self.binds.values(): d.setdefault(b.key, b.typeprocess(b.value)) - return d + if self.positional: + return [d[key] for key in self.positiontup] + else: + return d def visit_column(self, column): if column.table.name is None: @@ -145,7 +180,7 @@ class ANSICompiler(sql.Compiled): self.strings[bindparam] = self.bindparam_string(key) def bindparam_string(self, name): - return ":" + name + return self.bindtemplate % name def visit_alias(self, alias): self.froms[alias] = self.get_from_text(alias.selectable) + " AS " + alias.name @@ -167,7 +202,10 @@ class ANSICompiler(sql.Compiled): else: collist = string.join([c.fullname for c in inner_columns], ', ') - text = "SELECT " + collist + " \nFROM " + text = "SELECT " + if select.distinct: + text += "DISTINCT " + text += collist + " \nFROM " whereclause = select.whereclause