]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
removed the dependency of ANSICompiler on SQLEngine. you can now make ANSICompilers...
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 4 Mar 2006 18:53:35 +0000 (18:53 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 4 Mar 2006 18:53:35 +0000 (18:53 +0000)
lib/sqlalchemy/ansisql.py
lib/sqlalchemy/databases/firebird.py
lib/sqlalchemy/databases/information_schema.py
lib/sqlalchemy/databases/mysql.py
lib/sqlalchemy/databases/oracle.py
lib/sqlalchemy/databases/postgres.py
lib/sqlalchemy/databases/sqlite.py
lib/sqlalchemy/engine.py
lib/sqlalchemy/sql.py
test/select.py

index 1b600a4a847cdaa00177a972d977198842c4634c..7c0002aa58958cb89f94689b95641f8d049ac91a 100644 (file)
@@ -27,7 +27,7 @@ class ANSISQLEngine(sqlalchemy.engine.SQLEngine):
         return ANSISchemaDropper(self, **params)
 
     def compiler(self, statement, parameters, **kwargs):
-        return ANSICompiler(self, statement, parameters, **kwargs)
+        return ANSICompiler(statement, parameters, engine=self, **kwargs)
     
     def connect_args(self):
         return ([],{})
@@ -37,7 +37,7 @@ class ANSISQLEngine(sqlalchemy.engine.SQLEngine):
 
 class ANSICompiler(sql.Compiled):
     """default implementation of Compiled, which compiles ClauseElements into ANSI-compliant SQL strings."""
-    def __init__(self, engine, statement, parameters=None, typemap=None, **kwargs):
+    def __init__(self, statement, parameters=None, typemap=None, engine=None, positional=None, paramstyle=None, **kwargs):
         """constructs a new ANSICompiler object.
         
         engine - SQLEngine to compile against
@@ -49,7 +49,7 @@ class ANSICompiler(sql.Compiled):
         key/value pairs when the Compiled is executed, and also may affect the 
         actual compilation, as in the case of an INSERT where the actual columns
         inserted will correspond to the keys present in the parameters."""
-        sql.Compiled.__init__(self, engine, statement, parameters)
+        sql.Compiled.__init__(self, statement, parameters, engine=engine)
         self.binds = {}
         self.froms = {}
         self.wheres = {}
@@ -57,19 +57,31 @@ class ANSICompiler(sql.Compiled):
         self.select_stack = []
         self.typemap = typemap or {}
         self.isinsert = False
+        self.bindtemplate = ":%s"
+        if engine is not None:
+            self.paramstyle = engine.paramstyle
+            self.positional = engine.positional
+        else:
+            self.positional = False
+            self.paramstyle = 'named'
         
     def after_compile(self):
-        if self.engine.positional:
+        # this re will search for params like :param
+        # it has a negative lookbehind for an extra ':' so that it doesnt match
+        # postgres '::text' tokens
+        match = r'(?<!:):([\w_]+)'
+        if self.paramstyle=='pyformat':
+            self.strings[self.statement] = re.sub(match, lambda m:'%(' + m.group(1) +')s', self.strings[self.statement])
+        elif self.positional:
             self.positiontup = []
-            match = r'%\(([\w_]+)\)s'
             params = re.finditer(match, self.strings[self.statement])
             for p in params:
                 self.positiontup.append(p.group(1))
-            if self.engine.paramstyle=='qmark':
+            if self.paramstyle=='qmark':
                 self.strings[self.statement] = re.sub(match, '?', self.strings[self.statement])
-            elif self.engine.paramstyle=='format':
+            elif self.paramstyle=='format':
                 self.strings[self.statement] = re.sub(match, '%s', self.strings[self.statement])
-            elif self.engine.paramstyle=='numeric':
+            elif self.paramstyle=='numeric':
                 i = [0]
                 def getnum(x):
                     i[0] += 1
@@ -104,28 +116,33 @@ class ANSICompiler(sql.Compiled):
             bindparams = {}
         bindparams.update(params)
 
-        if self.engine.positional:
+        if self.positional:
             d = OrderedDict()
             for k in self.positiontup:
                 b = self.binds[k]
-                d[k] = b.typeprocess(b.value, self.engine)
+                if self.engine is not None:
+                    d[k] = b.typeprocess(b.value, self.engine)
+                else:
+                    d[k] = b.value
         else:
             d = {}
             for b in self.binds.values():
-                d[b.key] = b.typeprocess(b.value, self.engine)
+                if self.engine is not None:
+                    d[b.key] = b.typeprocess(b.value, self.engine)
+                else:
+                    d[b.key] = b.value
             
         for key, value in bindparams.iteritems():
             try:
                 b = self.binds[key]
             except KeyError:
                 continue
-            d[b.key] = b.typeprocess(value, self.engine)
+            if self.engine is not None:
+                d[b.key] = b.typeprocess(value, self.engine)
+            else:
+                d[b.key] = value
 
         return d
-        if self.engine.positional:
-            return d.values()
-        else:
-            return d
 
     def get_named_params(self, parameters):
         """given the results of the get_params method, returns the parameters
@@ -133,8 +150,7 @@ class ANSICompiler(sql.Compiled):
         same dictionary.  For a positional paramstyle, the given parameters are
         assumed to be in list format and are converted back to a dictionary.
         """
-#        return parameters
-        if self.engine.positional:
+        if self.positional:
             p = {}
             for i in range(0, len(self.positiontup)):
                 p[self.positiontup[i]] = parameters[i]
@@ -231,7 +247,7 @@ class ANSICompiler(sql.Compiled):
         self.strings[bindparam] = self.bindparam_string(key)
 
     def bindparam_string(self, name):
-        return self.engine.bindtemplate % name
+        return self.bindtemplate % name
         
     def visit_alias(self, alias):
         self.froms[alias] = self.get_from_text(alias.original) + " AS " + alias.name
index 2509148130c702fd64d16297cc505c0b57b45dbb..4dd4aa2a67ceb230f4dc352749192476b80ac914 100644 (file)
@@ -102,7 +102,7 @@ class FBSQLEngine(ansisql.ANSISQLEngine):
         return self.context.last_inserted_ids
 
     def compiler(self, statement, bindparams, **kwargs):
-        return FBCompiler(self, statement, bindparams, use_ansi=self._use_ansi, **kwargs)
+        return FBCompiler(statement, bindparams, engine=self, use_ansi=self._use_ansi, **kwargs)
 
     def schemagenerator(self, **params):
         return FBSchemaGenerator(self, **params)
index 825e0017a804bca03e763847e68e83e78e2c4b12..feb3cf0c52cf0c3c588ff9c504b0a7e049c9468a 100644 (file)
@@ -132,7 +132,7 @@ def reflecttable(engine, table, ischema_names, use_mysql=False):
         coltype = coltype(*args)
         colargs= []
         if default is not None:
-            colargs.append(PassiveDefault(sql.text(default, escape=False)))
+            colargs.append(PassiveDefault(sql.text(default)))
         table.append_item(schema.Column(name, coltype, nullable=nullable, *colargs))
 
     s = select([constraints.c.constraint_name, constraints.c.constraint_type, constraints.c.table_name, key_constraints], use_labels=True, from_obj=[constraints.join(column_constraints, column_constraints.c.constraint_name==constraints.c.constraint_name).join(key_constraints, key_constraints.c.constraint_name==column_constraints.c.constraint_name)])
index b29078ca2e07ffef3abb6e4ab564586b7deb218c..8e305a697b0cd0239a6ec68ebc3de7462a32557b 100644 (file)
@@ -132,7 +132,7 @@ class MySQLEngine(ansisql.ANSISQLEngine):
         return False
 
     def compiler(self, statement, bindparams, **kwargs):
-        return MySQLCompiler(self, statement, bindparams, **kwargs)
+        return MySQLCompiler(statement, bindparams, engine=self, **kwargs)
 
     def schemagenerator(self, **params):
         return MySQLSchemaGenerator(self, **params)
index b26298c7780cab1465258513611fdf8eddab844e..6f5e98265cadd352fbf4dbaa3a5f73b7bde43671 100644 (file)
@@ -151,7 +151,7 @@ class OracleSQLEngine(ansisql.ANSISQLEngine):
                
             colargs = []
             if default is not None:
-                colargs.append(PassiveDefault(sql.text(default, escape=False)))
+                colargs.append(PassiveDefault(sql.text(default)))
             
             name = name.lower()
             
@@ -207,7 +207,7 @@ class OracleCompiler(ansisql.ANSICompiler):
     def __init__(self, engine, statement, parameters, use_ansi = True, **kwargs):
         self._outertable = None
         self._use_ansi = use_ansi
-        ansisql.ANSICompiler.__init__(self, engine, statement, parameters, **kwargs)
+        ansisql.ANSICompiler.__init__(self, statement, parameters, engine=engine, **kwargs)
         
     def visit_join(self, join):
         if self._use_ansi:
index 92407637f940e1aacc1499c9216fed119ee1328d..592bac79c8b4af46bd25f5e18ddf182a9fd37da2 100644 (file)
@@ -210,7 +210,7 @@ class PGSQLEngine(ansisql.ANSISQLEngine):
             return sqltypes.adapt_type(typeobj, pg1_colspecs)
 
     def compiler(self, statement, bindparams, **kwargs):
-        return PGCompiler(self, statement, bindparams, **kwargs)
+        return PGCompiler(statement, bindparams, engine=self, **kwargs)
 
     def schemagenerator(self, **params):
         return PGSchemaGenerator(self, **params)
index 2e366e43246d8aeedef81cdc1055cab2a3f83933..6dc880b05150af4e17c7608ad497d2b759bfefdb 100644 (file)
@@ -147,7 +147,7 @@ class SQLiteSQLEngine(ansisql.ANSISQLEngine):
         return ([self.filename], self.opts)
 
     def compiler(self, statement, bindparams, **kwargs):
-        return SQLiteCompiler(self, statement, bindparams, **kwargs)
+        return SQLiteCompiler(statement, bindparams, engine=self, **kwargs)
 
     def dbapi(self):
         return sqlite
index e57cc7bc323ca3f9b53771e6a4575ba03579ab32..d07dd57341e18aedf9ce5ef88acb87179654fdd9 100644 (file)
@@ -227,15 +227,12 @@ class SQLEngine(schema.SchemaEngine):
             self._paramstyle = 'named'
 
         if self._paramstyle == 'named':
-            self.bindtemplate = ':%s'
             self.positional=False
         elif self._paramstyle == 'pyformat':
-            self.bindtemplate = "%%(%s)s"
             self.positional=False
         elif self._paramstyle == 'qmark' or self._paramstyle == 'format' or self._paramstyle == 'numeric':
             # for positional, use pyformat internally, ANSICompiler will convert
             # to appropriate character upon compilation
-            self.bindtemplate = "%%(%s)s"
             self.positional = True
         else:
             raise DBAPIError("Unsupported paramstyle '%s'" % self._paramstyle)
@@ -310,8 +307,7 @@ class SQLEngine(schema.SchemaEngine):
         instance of this engine's SQLCompiler, compiles the ClauseElement, and returns the
         newly compiled object."""
         compiler = self.compiler(statement, parameters, **kwargs)
-        statement.accept_visitor(compiler)
-        compiler.after_compile()
+        compiler.compile()
         return compiler
 
     def reflecttable(self, table):
index 6b574861be8d0f9166cb7eebc4030237125f3136..f05310e425ccc125ce815fbd111bd4cd0fd8969b 100644 (file)
@@ -257,11 +257,9 @@ class Compiled(ClauseVisitor):
     object be dependent on the actual values of those bind parameters, even though it may
     reference those values as defaults."""
 
-    def __init__(self, engine, statement, parameters):
+    def __init__(self, statement, parameters, engine=None):
         """constructs a new Compiled object.
         
-        engine - SQLEngine to compile against
-        
         statement - ClauseElement to be compiled
         
         parameters - optional dictionary indicating a set of bind parameters
@@ -271,10 +269,12 @@ class Compiled(ClauseVisitor):
         will also result in the creation of new BindParamClause objects for each key
         and will also affect the generated column list in an INSERT statement and the SET 
         clauses of an UPDATE statement.  The keys of the parameter dictionary can
-        either be the string names of columns or actual sqlalchemy.schema.Column objects."""
-        self.engine = engine
+        either be the string names of columns or ColumnClause objects.
+        
+        engine - optional SQLEngine to compile this statement against"""
         self.parameters = parameters
         self.statement = statement
+        self.engine = engine
 
     def __str__(self):
         """returns the string text of the generated SQL statement."""
@@ -290,6 +290,10 @@ class Compiled(ClauseVisitor):
         """
         raise NotImplementedError()
 
+    def compile(self):
+        self.statement.accept_visitor(self)
+        self.after_compile()
+
     def execute(self, *multiparams, **params):
         """executes this compiled object using the underlying SQLEngine"""
         if len(multiparams):
@@ -367,20 +371,25 @@ class ClauseElement(object):
             return None
             
     engine = property(lambda s: s._find_engine(), doc="attempts to locate a SQLEngine within this ClauseElement structure, or returns None if none found.")
-    
-    def compile(self, engine = None, parameters = None, typemap=None):
+
+
+    def compile(self, engine = None, parameters = None, typemap=None, compiler=None):
         """compiles this SQL expression using its underlying SQLEngine to produce
         a Compiled object.  If no engine can be found, an ansisql engine is used.
         bindparams is a dictionary representing the default bind parameters to be used with 
         the statement.  """
-        if engine is None:
-            engine = self.engine
-
-        if engine is None:
+        
+        if compiler is None:
+            if engine is not None:
+                compiler = engine.compiler(self, parameters)
+            elif self.engine is not None:
+                compiler = self.engine.compiler(self, parameters)
+                
+        if compiler is None:
             import sqlalchemy.ansisql as ansisql
-            engine = ansisql.engine()
-
-        return engine.compile(self, parameters=parameters, typemap=typemap)
+            compiler = ansisql.ANSICompiler(self, parameters=parameters, typemap=typemap)
+        compiler.compile()
+        return compiler
 
     def __str__(self):
         return str(self.compile())
@@ -638,7 +647,7 @@ class TextClause(ClauseElement):
     being specified as a bind parameter via the bindparam() method,
     since it provides more information about what it is, including an optional
     type, as well as providing comparison operations."""
-    def __init__(self, text = "", engine=None, bindparams=None, typemap=None, escape=True):
+    def __init__(self, text = "", engine=None, bindparams=None, typemap=None):
         self.parens = False
         self._engine = engine
         self.id = id(self)
@@ -649,12 +658,10 @@ class TextClause(ClauseElement):
                 typemap[key] = engine.type_descriptor(typemap[key])
         def repl(m):
             self.bindparams[m.group(1)] = bindparam(m.group(1))
-            return self.engine.bindtemplate % m.group(1)
-        
-        if escape: 
-            self.text = re.compile(r':([\w_]+)', re.S).sub(repl, text)
-        else:
-            self.text = text
+            return ":%s" % m.group(1)
+        # scan the string and search for bind parameter names, add them 
+        # to the list of bindparams
+        self.text = re.compile(r'(?<!:):([\w_]+)', re.S).sub(repl, text)
         if bindparams is not None:
             for b in bindparams:
                 self.bindparams[b.key] = b
index eb21022108b1538cb0f466a977187d76d92e6b5c..20454a9fc4502a85bfd5242a52e403f977da1e30 100644 (file)
@@ -57,15 +57,16 @@ addresses = table('addresses',
 
 class SQLTest(PersistTest):
     def runtest(self, clause, result, engine = None, params = None, checkparams = None):
-        c = clause.compile(parameters = params, engine=engine)
+        print "TEST with e", engine
+        c = clause.compile(parameters=params, engine=engine)
         self.echo("\nSQL String:\n" + str(c) + repr(c.get_params()))
         cc = re.sub(r'\n', '', str(c))
         self.assert_(cc == result, str(c) + "\n does not match \n" + result)
         if checkparams is not None:
             if isinstance(checkparams, list):
-                self.assert_(c.get_params().values() == checkparams, "params dont match")
+                self.assert_(c.get_params().values() == checkparams, "params dont match ")
             else:
-                self.assert_(c.get_params() == checkparams, "params dont match")
+                self.assert_(c.get_params() == checkparams, "params dont match" + repr(c.get_params()))
             
 class SelectTest(SQLTest):
     def testtableselect(self):