]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- mxodbc can use default execute() call
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 17 Mar 2010 00:58:46 +0000 (20:58 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 17 Mar 2010 00:58:46 +0000 (20:58 -0400)
- modified SQLCompiler to support rendering of bind parameters as literal
inline strings for specific sections, if specified by the compiler
subclass, using either literal_binds=True passed to process() or any visit
method, or by setting to False the "binds_in_columns_clause" flag for SQL-92
compatible columns clauses..  The compiler subclass is responsible for
implementing the literal quoting function which should make use of the DBAPI's native
capabilities.
- SQLCompiler now passes **kw to most process() methods (should be all,
ideally) so that literal_binds is propagated.
- added some rudimentary tests for mxodbc.

lib/sqlalchemy/connectors/mxodbc.py
lib/sqlalchemy/dialects/maxdb/base.py
lib/sqlalchemy/dialects/mssql/base.py
lib/sqlalchemy/dialects/mssql/mxodbc.py
lib/sqlalchemy/dialects/sybase/base.py
lib/sqlalchemy/sql/compiler.py
test/dialect/test_mssql.py

index 29b047d23b8c2b805e5263f48a9c535d27613ac1..68b88019c248ab221acc8b3d314d96387fcafbaf 100644 (file)
@@ -96,9 +96,4 @@ class MxODBCConnector(Connector):
                 version.append(n)
         return tuple(version)
     
-    def do_execute(self, cursor, statement, parameters, context=None):
-        # TODO: dont need tuple() here
-        # TODO: use cursor.execute()
-        cursor.executedirect(statement, tuple(parameters))
-
 
index 504c31209d428fffb2d57dfd12eafe1cef47c0e5..758cfaf052d3659c1346b25579daf7cbdf1f388e 100644 (file)
@@ -558,8 +558,8 @@ class MaxDBCompiler(compiler.SQLCompiler):
 
         return labels
 
-    def order_by_clause(self, select):
-        order_by = self.process(select._order_by_clause)
+    def order_by_clause(self, select, **kw):
+        order_by = self.process(select._order_by_clause, **kw)
 
         # ORDER BY clauses in DISTINCT queries must reference aliased
         # inner columns by alias name, not true column name.
index 254aa54fd3bea7ec0e92e8f019b0d66a94499720..eb4073b9404327edd7cc73ac4b531e2e028776e4 100644 (file)
@@ -1011,8 +1011,8 @@ class MSSQLCompiler(compiler.SQLCompiler):
         # "FOR UPDATE" is only allowed on "DECLARE CURSOR" which SQLAlchemy doesn't use
         return ''
 
-    def order_by_clause(self, select):
-        order_by = self.process(select._order_by_clause)
+    def order_by_clause(self, select, **kw):
+        order_by = self.process(select._order_by_clause, **kw)
 
         # MSSQL only allows ORDER BY in subqueries if there is a LIMIT
         if order_by and (not self.is_subquery() or select._limit):
index bf14601b871c68d454b941564290292e1ba5f81e..3dcc78b8c97d5f64e327291e96702301faaf89d0 100644 (file)
@@ -4,9 +4,44 @@ import sys
 from sqlalchemy import types as sqltypes
 from sqlalchemy.connectors.mxodbc import MxODBCConnector
 from sqlalchemy.dialects.mssql.pyodbc import MSExecutionContext_pyodbc
-from sqlalchemy.dialects.mssql.base import MSExecutionContext, MSDialect
+from sqlalchemy.dialects.mssql.base import MSExecutionContext, MSDialect, MSSQLCompiler
 
+# TODO: does Pyodbc on windows have the same limitations ?
+# if so this compiler can be moved to a common "odbc.py" module
+# here
+# *or* - should we implement this for MS-SQL across the board 
+# since its technically MS-SQL's behavior ?
+# perhaps yes, with a dialect flag "strict_binds" to turn it off
+class MSSQLCompiler_mxodbc(MSSQLCompiler):
+    binds_in_columns_clause = False
+    
+    def visit_in_op(self, binary, **kw):
+        kw['literal_binds'] = True
+        return "%s IN %s" % (
+                                self.process(binary.left, **kw), 
+                                self.process(binary.right, **kw)
+            )
 
+    def visit_notin_op(self, binary, **kw):
+        kw['literal_binds'] = True
+        return "%s NOT IN %s" % (
+                                self.process(binary.left, **kw), 
+                                self.process(binary.right, **kw)
+            )
+        
+    def visit_function(self, func, **kw):
+        kw['literal_binds'] = True
+        return super(MSSQLCompiler_mxodbc, self).visit_function(func, **kw)
+    
+    def render_literal_value(self, value):
+        # TODO! use mxODBC's literal quoting services here
+        if isinstance(value, basestring):
+            value = value.replace("'", "''")
+            return "'%s'" % value
+        else:
+            return repr(value)
+        
+        
 class MSExecutionContext_mxodbc(MSExecutionContext_pyodbc):
     """
     The pyodbc execution context is useful for enabling
@@ -20,7 +55,11 @@ class MSExecutionContext_mxodbc(MSExecutionContext_pyodbc):
 class MSDialect_mxodbc(MxODBCConnector, MSDialect):
 
     execution_ctx_cls = MSExecutionContext_mxodbc
-
+    
+    # TODO: may want to use this only if FreeTDS is not in use,
+    # since FreeTDS doesn't seem to use native binds.
+    statement_compiler = MSSQLCompiler_mxodbc
+    
     def __init__(self, description_encoding='latin-1', **params):
         super(MSDialect_mxodbc, self).__init__(**params)
         self.description_encoding = description_encoding
index b3ac455588b5bf1f762f3d0a59bcad9562009a8c..2addba2f89a9db624603ae9f6cffb38b345e5931 100644 (file)
@@ -265,8 +265,8 @@ class SybaseSQLCompiler(compiler.SQLCompiler):
         # "FOR UPDATE" is only allowed on "DECLARE CURSOR" which SQLAlchemy doesn't use
         return ''
 
-    def order_by_clause(self, select):
-        order_by = self.process(select._order_by_clause)
+    def order_by_clause(self, select, **kw):
+        order_by = self.process(select._order_by_clause, **kw)
 
         # SybaseSQL only allows ORDER BY in subqueries if there is a LIMIT
         if order_by and (not self.is_subquery() or select._limit):
index be3375def5e08bd862a4b8861765af6055ea8840..a3008d085fd70eafd05773cb7aff37becf62e7e3 100644 (file)
@@ -183,6 +183,11 @@ class SQLCompiler(engine.Compiled):
     # clauses before the VALUES or WHERE clause (i.e. MSSQL)
     returning_precedes_values = False
     
+    # SQL 92 doesn't allow bind parameters to be used
+    # in the columns clause of a SELECT.  A compiler
+    # subclass can set this flag to False if the target
+    # driver/DB enforces this
+    binds_in_columns_clause = True
     
     def __init__(self, dialect, statement, column_keys=None, inline=False, **kwargs):
         """Construct a new ``DefaultCompiler`` object.
@@ -260,9 +265,14 @@ class SQLCompiler(engine.Compiled):
                 else:
                     if bindparam.required:
                         if _group_number:
-                            raise exc.InvalidRequestError("A value is required for bind parameter %r, in parameter group %d" % (bindparam.key, _group_number))
+                            raise exc.InvalidRequestError(
+                                            "A value is required for bind parameter %r, "
+                                            "in parameter group %d" % 
+                                            (bindparam.key, _group_number))
                         else:
-                            raise exc.InvalidRequestError("A value is required for bind parameter %r" % bindparam.key)
+                            raise exc.InvalidRequestError(
+                                            "A value is required for bind parameter %r" 
+                                            % bindparam.key)
                     elif util.callable(bindparam.value):
                         pd[name] = bindparam.value()
                     else:
@@ -290,8 +300,8 @@ class SQLCompiler(engine.Compiled):
         """
         return ""
 
-    def visit_grouping(self, grouping, **kwargs):
-        return "(" + self.process(grouping.element) + ")"
+    def visit_grouping(self, grouping, asfrom=False, **kwargs):
+        return "(" + self.process(grouping.element, **kwargs) + ")"
 
     def visit_label(self, label, result_map=None, within_columns_clause=False):
         # only render labels within the columns clause
@@ -384,27 +394,28 @@ class SQLCompiler(engine.Compiled):
             sep = " "
         else:
             sep = OPERATORS[clauselist.operator]
-        return sep.join(s for s in (self.process(c) for c in clauselist.clauses)
+        return sep.join(s for s in (self.process(c, **kwargs) for c in clauselist.clauses)
                         if s is not None)
 
     def visit_case(self, clause, **kwargs):
         x = "CASE "
         if clause.value is not None:
-            x += self.process(clause.value) + " "
+            x += self.process(clause.value, **kwargs) + " "
         for cond, result in clause.whens:
-            x += "WHEN " + self.process(cond) + " THEN " + self.process(result) + " "
+            x += "WHEN " + self.process(cond, **kwargs) + \
+                            " THEN " + self.process(result, **kwargs) + " "
         if clause.else_ is not None:
-            x += "ELSE " + self.process(clause.else_) + " "
+            x += "ELSE " + self.process(clause.else_, **kwargs) + " "
         x += "END"
         return x
 
     def visit_cast(self, cast, **kwargs):
         return "CAST(%s AS %s)" % \
-                    (self.process(cast.clause), self.process(cast.typeclause))
+                    (self.process(cast.clause, **kwargs), self.process(cast.typeclause, **kwargs))
 
     def visit_extract(self, extract, **kwargs):
         field = self.extract_map.get(extract.field, extract.field)
-        return "EXTRACT(%s FROM %s)" % (field, self.process(extract.expr))
+        return "EXTRACT(%s FROM %s)" % (field, self.process(extract.expr, **kwargs))
 
     def visit_function(self, func, result_map=None, **kwargs):
         if result_map is not None:
@@ -421,22 +432,23 @@ class SQLCompiler(engine.Compiled):
     def function_argspec(self, func, **kwargs):
         return self.process(func.clause_expr, **kwargs)
 
-    def visit_compound_select(self, cs, asfrom=False, parens=True, **kwargs):
+    def visit_compound_select(self, cs, asfrom=False, parens=True, compound_index=1, **kwargs):
         entry = self.stack and self.stack[-1] or {}
         self.stack.append({'from':entry.get('from', None), 'iswrapper':True})
 
         keyword = self.compound_keywords.get(cs.keyword)
         
         text = (" " + keyword + " ").join(
-                            (self.process(c, asfrom=asfrom, parens=False, compound_index=i)
+                            (self.process(c, asfrom=asfrom, parens=False, 
+                                            compound_index=i, **kwargs)
                             for i, c in enumerate(cs.selects))
                         )
                         
-        group_by = self.process(cs._group_by_clause, asfrom=asfrom)
+        group_by = self.process(cs._group_by_clause, asfrom=asfrom, **kwargs)
         if group_by:
             text += " GROUP BY " + group_by
 
-        text += self.order_by_clause(cs)
+        text += self.order_by_clause(cs, **kwargs)
         text += (cs._limit is not None or cs._offset is not None) and self.limit_clause(cs) or ""
 
         self.stack.pop(-1)
@@ -457,28 +469,38 @@ class SQLCompiler(engine.Compiled):
         
         return self._operator_dispatch(binary.operator,
                     binary,
-                    lambda opstr: self.process(binary.left) + opstr + self.process(binary.right),
+                    lambda opstr: self.process(binary.left, **kwargs) + 
+                                        opstr + 
+                                    self.process(binary.right, **kwargs),
                     **kwargs
         )
 
     def visit_like_op(self, binary, **kw):
         escape = binary.modifiers.get("escape", None)
-        return '%s LIKE %s' % (self.process(binary.left), self.process(binary.right)) \
+        return '%s LIKE %s' % (
+                                    self.process(binary.left, **kw), 
+                                    self.process(binary.right, **kw)) \
             + (escape and ' ESCAPE \'%s\'' % escape or '')
 
     def visit_notlike_op(self, binary, **kw):
         escape = binary.modifiers.get("escape", None)
-        return '%s NOT LIKE %s' % (self.process(binary.left), self.process(binary.right)) \
+        return '%s NOT LIKE %s' % (
+                                    self.process(binary.left, **kw), 
+                                    self.process(binary.right, **kw)) \
             + (escape and ' ESCAPE \'%s\'' % escape or '')
         
     def visit_ilike_op(self, binary, **kw):
         escape = binary.modifiers.get("escape", None)
-        return 'lower(%s) LIKE lower(%s)' % (self.process(binary.left), self.process(binary.right)) \
+        return 'lower(%s) LIKE lower(%s)' % (
+                                            self.process(binary.left, **kw), 
+                                            self.process(binary.right, **kw)) \
             + (escape and ' ESCAPE \'%s\'' % escape or '')
     
     def visit_notilike_op(self, binary, **kw):
         escape = binary.modifiers.get("escape", None)
-        return 'lower(%s) NOT LIKE lower(%s)' % (self.process(binary.left), self.process(binary.right)) \
+        return 'lower(%s) NOT LIKE lower(%s)' % (
+                                            self.process(binary.left, **kw), 
+                                            self.process(binary.right, **kw)) \
             + (escape and ' ESCAPE \'%s\'' % escape or '')
         
     def _operator_dispatch(self, operator, element, fn, **kw):
@@ -491,7 +513,14 @@ class SQLCompiler(engine.Compiled):
         else:
             return fn(" " + operator + " ")
         
-    def visit_bindparam(self, bindparam, **kwargs):
+    def visit_bindparam(self, bindparam, within_columns_clause=False, 
+                                            literal_binds=False, **kwargs):
+        if literal_binds or \
+            (within_columns_clause and \
+                not self.binds_in_columns_clause) and \
+                    bindparam.value is not None:
+            return self.render_literal_bindparam(bindparam, within_columns_clause=True, **kwargs)
+            
         name = self._truncate_bindparam(bindparam)
         if name in self.binds:
             existing = self.binds[name]
@@ -510,7 +539,26 @@ class SQLCompiler(engine.Compiled):
                     
         self.binds[bindparam.key] = self.binds[name] = bindparam
         return self.bindparam_string(name)
-
+    
+    def render_literal_bindparam(self, bindparam, **kw):
+        value = bindparam.value
+        processor = bindparam.bind_processor(self.dialect)
+        if processor:
+            value = processor(value)
+        return self.render_literal_value(value)
+        
+    def render_literal_value(self, value):
+        """Render the value of a bind parameter as a quoted literal.
+        
+        This is used for statement sections that do not accept bind paramters
+        on the target driver/database.
+        
+        This should be implemented by subclasses using the quoting services
+        of the DBAPI.
+        
+        """
+        raise NotImplementedError()
+        
     def _truncate_bindparam(self, bindparam):
         if bindparam in self.bind_names:
             return self.bind_names[bindparam]
@@ -624,33 +672,33 @@ class SQLCompiler(engine.Compiled):
         
         text = "SELECT "  # we're off to a good start !
         if select._prefixes:
-            text += " ".join(self.process(x) for x in select._prefixes) + " "
+            text += " ".join(self.process(x, **kwargs) for x in select._prefixes) + " "
         text += self.get_select_precolumns(select)
         text += ', '.join(inner_columns)
 
         if froms:
             text += " \nFROM "
-            text += ', '.join(self.process(f, asfrom=True) for f in froms)
+            text += ', '.join(self.process(f, asfrom=True, **kwargs) for f in froms)
         else:
             text += self.default_from()
 
         if select._whereclause is not None:
-            t = self.process(select._whereclause)
+            t = self.process(select._whereclause, **kwargs)
             if t:
                 text += " \nWHERE " + t
 
         if select._group_by_clause.clauses:
-            group_by = self.process(select._group_by_clause)
+            group_by = self.process(select._group_by_clause, **kwargs)
             if group_by:
                 text += " GROUP BY " + group_by
 
         if select._having is not None:
-            t = self.process(select._having)
+            t = self.process(select._having, **kwargs)
             if t:
                 text += " \nHAVING " + t
 
         if select._order_by_clause.clauses:
-            text += self.order_by_clause(select)
+            text += self.order_by_clause(select, **kwargs)
         if select._limit is not None or select._offset is not None:
             text += self.limit_clause(select)
         if select.for_update:
@@ -670,8 +718,8 @@ class SQLCompiler(engine.Compiled):
         """
         return select._distinct and "DISTINCT " or ""
 
-    def order_by_clause(self, select):
-        order_by = self.process(select._order_by_clause)
+    def order_by_clause(self, select, **kw):
+        order_by = self.process(select._order_by_clause, **kw)
         if order_by:
             return " ORDER BY " + order_by
         else:
index 89a3af5fb55e54162339b252f63cef03b7d9838c..8092d8cdc877b43a0112632422bc51d71aeecfc6 100644 (file)
@@ -6,7 +6,7 @@ from sqlalchemy import types, exc, schema
 from sqlalchemy.orm import *
 from sqlalchemy.sql import table, column
 from sqlalchemy.databases import mssql
-from sqlalchemy.dialects.mssql import pyodbc
+from sqlalchemy.dialects.mssql import pyodbc, mxodbc
 from sqlalchemy.engine import url
 from sqlalchemy.test import *
 from sqlalchemy.test.testing import eq_, emits_warning_on
@@ -22,7 +22,35 @@ class CompileTest(TestBase, AssertsCompiledSQL):
     def test_update(self):
         t = table('sometable', column('somecolumn'))
         self.assert_compile(t.update(t.c.somecolumn==7), "UPDATE sometable SET somecolumn=:somecolumn WHERE sometable.somecolumn = :somecolumn_1", dict(somecolumn=10))
-
+    
+    # TODO: should this be for *all* MS-SQL dialects ?
+    def test_mxodbc_binds(self):
+        """mxodbc uses MS-SQL native binds, which aren't allowed in various places."""
+        
+        mxodbc_dialect = mxodbc.dialect()
+        t = table('sometable', column('foo'))
+        
+        for expr, compile in [
+            (
+                select([literal("x"), literal("y")]), 
+                "SELECT 'x', 'y'",
+            ),
+            (
+                select([t]).where(t.c.foo.in_(['x', 'y', 'z'])),
+                "SELECT sometable.foo FROM sometable WHERE sometable.foo IN ('x', 'y', 'z')",
+            ),
+            (
+                func.foobar("x", "y", 4, 5),
+                "foobar('x', 'y', 4, 5)",
+            ),
+            (
+                select([t]).where(func.len('xyz') > func.len(t.c.foo)),
+                "SELECT sometable.foo FROM sometable WHERE len('xyz') > len(sometable.foo)",
+            )
+        ]:
+            self.assert_compile(expr, compile, dialect=mxodbc_dialect)
+        
+        
     def test_in_with_subqueries(self):
         """Test that when using subqueries in a binary expression
         the == and != are changed to IN and NOT IN respectively.
@@ -127,15 +155,24 @@ class CompileTest(TestBase, AssertsCompiledSQL):
             column('col4'))
 
         (s1, s2) = (
-                    select([t1.c.col3.label('col3'), t1.c.col4.label('col4')], t1.c.col2.in_(["t1col2r1", "t1col2r2"])),
-            select([t2.c.col3.label('col3'), t2.c.col4.label('col4')], t2.c.col2.in_(["t2col2r2", "t2col2r3"]))
+                    select([t1.c.col3.label('col3'), t1.c.col4.label('col4')],
+                            t1.c.col2.in_(["t1col2r1", "t1col2r2"])),
+            select([t2.c.col3.label('col3'), t2.c.col4.label('col4')], 
+                            t2.c.col2.in_(["t2col2r2", "t2col2r3"]))
         )
         u = union(s1, s2, order_by=['col3', 'col4'])
-        self.assert_compile(u, "SELECT t1.col3 AS col3, t1.col4 AS col4 FROM t1 WHERE t1.col2 IN (:col2_1, :col2_2) "\
-        "UNION SELECT t2.col3 AS col3, t2.col4 AS col4 FROM t2 WHERE t2.col2 IN (:col2_3, :col2_4) ORDER BY col3, col4")
-
-        self.assert_compile(u.alias('bar').select(), "SELECT bar.col3, bar.col4 FROM (SELECT t1.col3 AS col3, t1.col4 AS col4 FROM t1 WHERE "\
-        "t1.col2 IN (:col2_1, :col2_2) UNION SELECT t2.col3 AS col3, t2.col4 AS col4 FROM t2 WHERE t2.col2 IN (:col2_3, :col2_4)) AS bar")
+        self.assert_compile(u, 
+                "SELECT t1.col3 AS col3, t1.col4 AS col4 FROM t1 WHERE t1.col2 IN "
+                "(:col2_1, :col2_2) "\
+                "UNION SELECT t2.col3 AS col3, t2.col4 AS col4 FROM t2 WHERE t2.col2 "
+                "IN (:col2_3, :col2_4) ORDER BY col3, col4")
+
+        self.assert_compile(u.alias('bar').select(), 
+                                "SELECT bar.col3, bar.col4 FROM (SELECT t1.col3 AS col3, "
+                                "t1.col4 AS col4 FROM t1 WHERE "\
+                                "t1.col2 IN (:col2_1, :col2_2) UNION SELECT t2.col3 AS col3, "
+                                "t2.col4 AS col4 FROM t2 WHERE t2.col2 IN (:col2_3, :col2_4)) "
+                                "AS bar")
 
     def test_function(self):
         self.assert_compile(func.foo(1, 2), "foo(:foo_1, :foo_2)")