]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
(no commit message)
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 23 Jul 2005 20:06:57 +0000 (20:06 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 23 Jul 2005 20:06:57 +0000 (20:06 +0000)
lib/sqlalchemy/ansisql.py
lib/sqlalchemy/databases/sqlite.py
lib/sqlalchemy/sql.py

index 0d93ea5188c3483e82b9e19fbf0dee56ed583030..3f6cbb835df6d699732b9d88580fffc01e2515eb 100644 (file)
@@ -55,7 +55,7 @@ class ANSISQLEngine(sqlalchemy.engine.SQLEngine):
 class ANSICompiler(sql.Compiled):
     def __init__(self, parent, bindparams):
         self.binds = {}
-        self.bindparams = bindparams
+        self._bindparams = bindparams
         self.parent = parent
         self.froms = {}
         self.wheres = {}
@@ -71,6 +71,8 @@ class ANSICompiler(sql.Compiled):
         return self.wheres.get(obj, None)
         
     def get_params(self, **params):
+        """returns the bind params for this compiled object, with values overridden by 
+        those given in the **params dictionary"""
         d = {}
         for key, value in params.iteritems():
             try:
@@ -80,8 +82,7 @@ class ANSICompiler(sql.Compiled):
             d[b.key] = value
 
         for b in self.binds.values():
-            if not d.has_key(b.key):
-                d[b.key] = b.value
+            d.setdefault(b.key, b.value)
 
         return d
         
@@ -166,7 +167,7 @@ class ANSICompiler(sql.Compiled):
             if t is not None:
                 froms.append(t)
 
-        text += string.join(froms, ', ')                
+        text += string.join(froms, ', ')
 
         if whereclause is not None:
             t = self.get_str(whereclause)
@@ -182,18 +183,17 @@ class ANSICompiler(sql.Compiled):
 
     def visit_table(self, table):
         self.froms[table] = table.name
-        
+
     def visit_join(self, join):
         if join.isouter:
             self.froms[join] = (self.get_from_text(join.left) + " LEFT OUTER JOIN " + self.get_from_text(join.right) + 
             " ON " + self.get_str(join.onclause))
         else:
-            self.froms[join] = (self.get_from_text(join.left) + " JOIN " + self.get_from_text(join.right) + 
+            self.froms[join] = (self.get_from_text(join.left) + " JOIN " + self.get_from_text(join.right) +
             " ON " + self.get_str(join.onclause))
-            
-            
+
     def visit_insert(self, insert_stmt):
-        colparams = insert_stmt.get_colparams(self.bindparams)
+        colparams = insert_stmt.get_colparams(self._bindparams)
 
         for c in colparams:
             b = c[1]
@@ -206,7 +206,7 @@ class ANSICompiler(sql.Compiled):
         self.strings[insert_stmt] = text
 
     def visit_update(self, update_stmt):
-        colparams = update_stmt.get_colparams(self.bindparams)
+        colparams = update_stmt.get_colparams(self._bindparams)
         
         for c in colparams:
             b = c[1]
index 6a1b58da9335c666b32293765977729a985833a7..fffda0916daab4f3c2ed5ebc9bdd55901e7c86b2 100644 (file)
@@ -51,6 +51,12 @@ class SQLiteSQLEngine(ansisql.ANSISQLEngine):
     
     def connect_args(self):
         return ([self.filename], self.opts)
+
+    def compile(self, statement, bindparams):
+        compiler = SQLiteCompiler(statement, bindparams)
+
+        statement.accept_visitor(compiler)
+        return compiler
         
     def dbapi(self):
         return sqlite
@@ -61,6 +67,10 @@ class SQLiteSQLEngine(ansisql.ANSISQLEngine):
     def reflecttable(self, table):
         raise NotImplementedError()
 
+class SQLiteCompiler(ansisql.ANSICompiler):
+    def visit_insert(self, insert):
+        ansisql.ANSICompiler.visit_insert(self, insert)
+        
 class SQLiteColumnImpl(sql.ColumnSelectable):
     def _get_specification(self):
         coltype = self.column.type
index 29de0d34fcc48a6659d78715f0578d5bc44df5ef..8f3e51fbd9ebd2307cadf2ba42f798fc6b5cbed8 100644 (file)
@@ -136,13 +136,12 @@ class ClauseElement(object):
         c = self.compile(e, bindparams = params)
         # TODO: do pre-execute right here, for sequences, if the compiled object
         # defines it
-        # TODO: why do we send the params twice, once to compile, once to c.get_params
-        return e.execute(str(c), c.get_params(**params), echo = getattr(self, 'echo', None))
+        return e.execute(str(c), c.get_params(), echo = getattr(self, 'echo', None))
 
     def result(self, **params):
         e = self._engine()
         c = self.compile(e, bindparams = params)
-        return e.result(str(c), c.get_params(**params))
+        return e.result(str(c), c.binds)
         
 class ColumnClause(ClauseElement):
     """represents a column clause element in a SQL statement."""