]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
fix for [ticket:169], moves the creation of "default" parameters more accurately
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 28 Apr 2006 23:31:59 +0000 (23:31 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 28 Apr 2006 23:31:59 +0000 (23:31 +0000)
where theyre supposed to be

lib/sqlalchemy/ansisql.py
lib/sqlalchemy/databases/postgres.py
test/defaults.py

index 7aefea0ba2e49d221338b4c3c11ca8e8d3f13b15..a344e017c77bdc2a59e6c52584d787f5c3b3edc9 100644 (file)
@@ -146,7 +146,6 @@ class ANSICompiler(sql.Compiled):
                 continue
             d.set_parameter(b.key, value, b)
 
-        #print "FROM", params, "TO", d
         return d
 
     def get_named_params(self, parameters):
@@ -425,26 +424,26 @@ class ANSICompiler(sql.Compiled):
             " ON " + self.get_str(join.onclause))
         self.strings[join] = self.froms[join]
 
-    def visit_insert_column_default(self, column, default):
+    def visit_insert_column_default(self, column, default, parameters):
         """called when visiting an Insert statement, for each column in the table that
         contains a ColumnDefault object.  adds a blank 'placeholder' parameter so the 
         Insert gets compiled with this column's name in its column and VALUES clauses."""
-        self.parameters.setdefault(column.key, None)
+        parameters.setdefault(column.key, None)
 
-    def visit_update_column_default(self, column, default):
+    def visit_update_column_default(self, column, default, parameters):
         """called when visiting an Update statement, for each column in the table that
         contains a ColumnDefault object as an onupdate. adds a blank 'placeholder' parameter so the 
         Update gets compiled with this column's name as one of its SET clauses."""
-        self.parameters.setdefault(column.key, None)
+        parameters.setdefault(column.key, None)
         
-    def visit_insert_sequence(self, column, sequence):
+    def visit_insert_sequence(self, column, sequence, parameters):
         """called when visiting an Insert statement, for each column in the table that
         contains a Sequence object.  Overridden by compilers that support sequences to place
         a blank 'placeholder' parameter, so the Insert gets compiled with this column's
         name in its column and VALUES clauses."""
         pass
     
-    def visit_insert_column(self, column):
+    def visit_insert_column(self, column, parameters):
         """called when visiting an Insert statement, for each column in the table
         that is a NULL insert into the table.  Overridden by compilers who disallow
         NULL columns being set in an Insert where there is a default value on the column
@@ -454,25 +453,27 @@ class ANSICompiler(sql.Compiled):
     def visit_insert(self, insert_stmt):
         # scan the table's columns for defaults that have to be pre-set for an INSERT
         # add these columns to the parameter list via visit_insert_XXX methods
+        default_params = {}
         class DefaultVisitor(schema.SchemaVisitor):
             def visit_column(s, c):
-                self.visit_insert_column(c)
+                self.visit_insert_column(c, default_params)
             def visit_column_default(s, cd):
-                self.visit_insert_column_default(c, cd)
+                self.visit_insert_column_default(c, cd, default_params)
             def visit_sequence(s, seq):
-                self.visit_insert_sequence(c, seq)
+                self.visit_insert_sequence(c, seq, default_params)
         vis = DefaultVisitor()
         for c in insert_stmt.table.c:
             if (isinstance(c, schema.SchemaItem) and (self.parameters is None or self.parameters.get(c.key, None) is None)):
                 c.accept_schema_visitor(vis)
         
         self.isinsert = True
-        colparams = self._get_colparams(insert_stmt)
+        colparams = self._get_colparams(insert_stmt, default_params)
 
         def create_param(p):
             if isinstance(p, sql.BindParamClause):
                 self.binds[p.key] = p
-                self.binds[p.shortname] = p
+                if p.shortname is not None:
+                    self.binds[p.shortname] = p
                 return self.bindparam_string(p.key)
             else:
                 p.accept_visitor(self)
@@ -483,22 +484,23 @@ class ANSICompiler(sql.Compiled):
 
         text = ("INSERT INTO " + insert_stmt.table.fullname + " (" + string.join([c[0].name for c in colparams], ', ') + ")" +
          " VALUES (" + string.join([create_param(c[1]) for c in colparams], ', ') + ")")
-         
+
         self.strings[insert_stmt] = text
 
     def visit_update(self, update_stmt):
         # scan the table's columns for onupdates that have to be pre-set for an UPDATE
         # add these columns to the parameter list via visit_update_XXX methods
+        default_params = {}
         class OnUpdateVisitor(schema.SchemaVisitor):
             def visit_column_onupdate(s, cd):
-                self.visit_update_column_default(c, cd)
+                self.visit_update_column_default(c, cd, default_params)
         vis = OnUpdateVisitor()
         for c in update_stmt.table.c:
             if (isinstance(c, schema.SchemaItem) and (self.parameters is None or self.parameters.get(c.key, None) is None)):
                 c.accept_schema_visitor(vis)
 
         self.isupdate = True
-        colparams = self._get_colparams(update_stmt)
+        colparams = self._get_colparams(update_stmt, default_params)
         def create_param(p):
             if isinstance(p, sql.BindParamClause):
                 self.binds[p.key] = p
@@ -519,7 +521,7 @@ class ANSICompiler(sql.Compiled):
         self.strings[update_stmt] = text
 
 
-    def _get_colparams(self, stmt):
+    def _get_colparams(self, stmt, default_params):
         """determines the VALUES or SET clause for an INSERT or UPDATE
         clause based on the arguments specified to this ANSICompiler object
         (i.e., the execute() or compile() method clause object):
@@ -550,6 +552,9 @@ class ANSICompiler(sql.Compiled):
             for k, v in stmt.parameters.iteritems():
                 parameters.setdefault(k, v)
 
+        for k, v in default_params.iteritems():
+            parameters.setdefault(k, v)
+            
         # now go thru compiled params, get the Column object for each key
         d = {}
         for key, value in parameters.iteritems():
index 8a063ca06e62943b8b89d30204c62803d210a067..19a703c0ce41a144e99817eb458b0055a8dac9e0 100644 (file)
@@ -291,13 +291,13 @@ class PGSQLEngine(ansisql.ANSISQLEngine):
 class PGCompiler(ansisql.ANSICompiler):
 
         
-    def visit_insert_column(self, column):
+    def visit_insert_column(self, column, parameters):
         # Postgres advises against OID usage and turns it off in 8.1,
         # effectively making cursor.lastrowid
         # useless, effectively making reliance upon SERIAL useless.  
         # so all column primary key inserts must be explicitly present
         if column.primary_key:
-            self.parameters[column.key] = None
+            parameters[column.key] = None
 
     def limit_clause(self, select):
         text = ""
index 0d91d12a4810088e982111f1a81071922cd53869..096355826fa4a9a185903f15f8af29b5365a513c 100644 (file)
@@ -92,6 +92,12 @@ class DefaultTest(PersistTest):
         l = t.select().execute()
         self.assert_(l.fetchall() == [(51, 'imthedefault', f, ts, ts, ctexec), (52, 'imthedefault', f, ts, ts, ctexec), (53, 'imthedefault', f, ts, ts, ctexec)])
 
+    def testinsertvalues(self):
+        t.insert(values={'col3':50}).execute()
+        l = t.select().execute()
+        self.assert_(l.fetchone()['col3'] == 50)
+        
+        
     def testupdate(self):
         t.insert().execute()
         pk = t.engine.last_inserted_ids()[0]
@@ -103,6 +109,14 @@ class DefaultTest(PersistTest):
         self.assert_(l == (pk, 'im the update', f2, None, None, ctexec))
         # mysql/other db's return 0 or 1 for count(1)
         self.assert_(14 <= f2 <= 15)
+
+    def testupdatevalues(self):
+        t.insert().execute()
+        pk = t.engine.last_inserted_ids()[0]
+        t.update(t.c.col1==pk, values={'col3': 55}).execute()
+        l = t.select(t.c.col1==pk).execute()
+        l = l.fetchone()
+        self.assert_(l['col3'] == 55)
         
 class SequenceTest(PersistTest):