]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
fix to ansisql when it tries to determine param-based select clause that its
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 31 Dec 2005 07:13:18 +0000 (07:13 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 31 Dec 2005 07:13:18 +0000 (07:13 +0000)
only on a column-type object
engine has settable 'paramstyle' attribute

lib/sqlalchemy/ansisql.py
lib/sqlalchemy/engine.py
lib/sqlalchemy/mapping/mapper.py
lib/sqlalchemy/sql.py

index 7f30a5da0d4c43104ba8a458dfaea8ec49bf842b..a4cbb43583c15eb961368635830f23178232c364 100644 (file)
@@ -70,11 +70,11 @@ class ANSICompiler(sql.Compiled):
             elif self.engine.paramstyle=='format':
                 self.strings[self.statement] = re.sub(match, '%s', self.strings[self.statement])
             elif self.engine.paramstyle=='numeric':
-                i = 0
+                i = [0]
                 def getnum(x):
-                    i += 1
-                    return i
-                self.strings[self.statement] = re.sub(match, getnum(s), self.strings[self.statement])
+                    i[0] += 1
+                    return str(i[0])
+                self.strings[self.statement] = re.sub(match, getnum, self.strings[self.statement])
 
     def get_from_text(self, obj):
         return self.froms[obj]
@@ -282,7 +282,7 @@ class ANSICompiler(sql.Compiled):
         if self.parameters is not None:
             revisit = False
             for c in inner_columns.values():
-                if self.parameters.has_key(c.key) and not self.binds.has_key(c.key):
+                if sql.is_column(c) and self.parameters.has_key(c.key) and not self.binds.has_key(c.key):
                     value = self.parameters[c.key]
                 else:
                     continue
index bc39d90d970c64995da090b79afb5ff89ec2a155..f237014247d0d387fa342f5155ee66f8c1e9d1ed 100644 (file)
@@ -176,23 +176,33 @@ class SQLEngine(schema.SchemaEngine):
         else:
             self.logger = logger
     
-    def _figure_paramstyle(self):
+    def _set_paramstyle(self, style):
+        self._paramstyle = style
+        self._figure_paramstyle(style)
+    paramstyle = property(lambda s:s._paramstyle, _set_paramstyle)
+    
+    def _figure_paramstyle(self, paramstyle=None):
         db = self.dbapi()
-        if db is not None:
-            self.paramstyle = db.paramstyle
+        if paramstyle is not None:
+            self._paramstyle = paramstyle
+        elif db is not None:
+            self._paramstyle = db.paramstyle
         else:
-            self.paramstyle = 'named'
+            self._paramstyle = 'named'
 
-        if self.paramstyle == 'named':
+        if self._paramstyle == 'named':
             self.bindtemplate = ':%s'
             self.positional=False
-        elif self.paramstyle =='pyformat':
+        elif self._paramstyle == 'pyformat':
             self.bindtemplate = "%%(%s)s"
             self.positional=False
-        else:
-            # for positional, use pyformat until the end
+        elif self._paramstyle == 'qmark' or self._paramstyle == 'format' or self._paramstyle == 'numeric':
+            # for positional, use pyformat internally, ANSICompiler will convert
+            # to appropriate character upon compilation
             self.bindtemplate = "%%(%s)s"
-            self.positional=True
+            self.positional = True
+        else:
+            raise "Unsupported paramstyle '%s'" % self._paramstyle
         
     def type_descriptor(self, typeobj):
         """provides a database-specific TypeEngine object, given the generic object
index 1920d4ff2e06a270c24346b123d0bbf9b8563007..27479ae9d62ccaa64b0258341935b7a7f3f6f957 100644 (file)
@@ -117,10 +117,10 @@ class Mapper(object):
         # load custom properties 
         if properties is not None:
             for key, prop in properties.iteritems():
-                if is_column(prop):
+                if sql.is_column(prop):
                     self.columns[key] = prop
                     prop = ColumnProperty(prop)
-                elif isinstance(prop, list) and is_column(prop[0]):
+                elif isinstance(prop, list) and sql.is_column(prop[0]):
                     self.columns[key] = prop[0]
                     prop = ColumnProperty(*prop)
                 self.props[key] = prop
@@ -170,7 +170,7 @@ class Mapper(object):
 
     def add_property(self, key, prop):
         self.copyargs['properties'][key] = prop
-        if is_column(prop):
+        if sql.is_column(prop):
             self.columns[key] = prop
             prop = ColumnProperty(prop)
         self.props[key] = prop
@@ -797,9 +797,6 @@ def hash_key(obj):
         return obj.hash_key()
     else:
         return repr(obj)
-
-def is_column(col):
-    return isinstance(col, schema.Column) or isinstance(col, sql.ColumnElement)
     
 def mapper_hash_key(class_, table, primarytable = None, properties = None, **kwargs):
     if properties is None:
index 6e10242510c3f6dda483de9b7ebb706a923583a5..d7b1ac021887ffb0d60ffc85f593b83595687a74 100644 (file)
@@ -201,6 +201,9 @@ def _compound_select(keyword, *selects, **kwargs):
 def _is_literal(element):
     return not isinstance(element, ClauseElement) and not isinstance(element, schema.SchemaItem)
 
+def is_column(col):
+    return isinstance(col, schema.Column) or isinstance(col, ColumnElement)
+
 class ClauseVisitor(schema.SchemaVisitor):
     """builds upon SchemaVisitor to define the visiting of SQL statement elements in 
     addition to Schema elements."""