]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
(no commit message)
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 5 Nov 2005 05:17:15 +0000 (05:17 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 5 Nov 2005 05:17:15 +0000 (05:17 +0000)
lib/sqlalchemy/ansisql.py

index 1209b2324eb8f90baace141099110dab2a2ebb75..7205a5e9fb0976f7e5df1a76103c6cbc2d708dde 100644 (file)
@@ -46,9 +46,8 @@ class ANSISQLEngine(sqlalchemy.engine.SQLEngine):
     def dbapi(self):
         return object()
 
-
 class ANSICompiler(sql.Compiled):
-    def __init__(self, engine, statement, bindparams, typemap=None, **kwargs):
+    def __init__(self, engine, statement, bindparams, typemap=None, paramstyle=None,**kwargs):
         sql.Compiled.__init__(self, engine, statement, bindparams)
         self.binds = {}
         self.froms = {}
@@ -57,6 +56,39 @@ class ANSICompiler(sql.Compiled):
         self.typemap = typemap or {}
         self.isinsert = False
         
+        if paramstyle is None:
+            paramstyle = self.engine.dbapi().paramstyle
+
+        if paramstyle == 'named':
+            self.bindtemplate = ':%s'
+            self.positional=False
+        elif paramstyle =='pyformat':
+            self.bindtemplate = "%%(%s)s"
+            self.positional=False
+        else:
+            # for positional, use pyformat until the end
+            self.bindtemplate = "%%(%s)s"
+            self.positional=True
+        self.paramstyle=paramstyle
+        
+    def after_compile(self):
+        if self.positional:
+            self.positiontup = []
+            match = r'%\(([\w_]+)\)s'
+            params = re.finditer(match, self.strings[self.statement])
+            for p in params:
+                self.positiontup.append(p.group(1))
+            if self.paramstyle=='qmark':
+                self.strings[self.statement] = re.sub(match, '?', self.strings[self.statement])
+            elif self.paramstyle=='format':
+                self.strings[self.statement] = re.sub(match, '%s', self.strings[self.statement])
+            elif self.paramstyle=='numeric':
+                i = 0
+                def getnum(x):
+                    i += 1
+                    return i
+                self.strings[self.statement] = re.sub(match, getnum(s), self.strings[self.statement])
+
     def get_from_text(self, obj):
         return self.froms[obj]
 
@@ -85,7 +117,10 @@ class ANSICompiler(sql.Compiled):
         for b in self.binds.values():
             d.setdefault(b.key, b.typeprocess(b.value))
 
-        return d
+        if self.positional:
+            return [d[key] for key in self.positiontup]
+        else:
+            return d
 
     def visit_column(self, column):
         if column.table.name is None:
@@ -145,7 +180,7 @@ class ANSICompiler(sql.Compiled):
         self.strings[bindparam] = self.bindparam_string(key)
 
     def bindparam_string(self, name):
-        return ":" + name
+        return self.bindtemplate % name
         
     def visit_alias(self, alias):
         self.froms[alias] = self.get_from_text(alias.selectable) + " AS " + alias.name
@@ -167,7 +202,10 @@ class ANSICompiler(sql.Compiled):
         else:
             collist = string.join([c.fullname for c in inner_columns], ', ')
 
-        text = "SELECT " + collist + " \nFROM "
+        text = "SELECT "
+        if select.distinct:
+            text += "DISTINCT "
+        text += collist + " \nFROM "
         
         whereclause = select.whereclause