]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- insert() and update() constructs can now embed bindparam()
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 23 Oct 2009 01:08:02 +0000 (01:08 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 23 Oct 2009 01:08:02 +0000 (01:08 +0000)
      objects using names that match the keys of columns.  These
      bind parameters will circumvent the usual route to those
      keys showing up in the VALUES or SET clause of the generated
      SQL. [ticket:1579]

CHANGES
lib/sqlalchemy/dialects/oracle/base.py
lib/sqlalchemy/sql/compiler.py
test/sql/test_select.py

diff --git a/CHANGES b/CHANGES
index 86baaa78a48f6cad2723e877b4a885f8ddb8afd7..baa34e7896b16e36e5794f4b1c44e0921e325b75 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -191,6 +191,12 @@ CHANGES
       performed). This occurs if no end-user returning() was
       specified.
 
+    - insert() and update() constructs can now embed bindparam()
+      objects using names that match the keys of columns.  These
+      bind parameters will circumvent the usual route to those 
+      keys showing up in the VALUES or SET clause of the generated
+      SQL. [ticket:1579]
+      
     - Databases which rely upon postfetch of "last inserted id"
       to get at a generated sequence value (i.e. MySQL, MS-SQL)
       now work correctly when there is a composite primary key
index 8af28c8e53f6b3f7c02b413f5edd224fb5cd9a68..689b518f1d6aa3c76c35ce5a8e3c5227af0976d8 100644 (file)
@@ -465,13 +465,13 @@ class OracleDDLCompiler(compiler.DDLCompiler):
 class OracleIdentifierPreparer(compiler.IdentifierPreparer):
     
     reserved_words = set([x.lower() for x in RESERVED_WORDS])
-    illegal_initial_characters = re.compile(r'[0-9_$]')
+    illegal_initial_characters = set(xrange(0, 10)).union(["_", "$"])
 
     def _bindparam_requires_quotes(self, value):
         """Return True if the given identifier requires quoting."""
         lc_value = value.lower()
         return (lc_value in self.reserved_words
-                or self.illegal_initial_characters.match(value[0])
+                or value[0] in self.illegal_initial_characters
                 or not self.legal_characters.match(unicode(value))
                 )
     
index 4c31308796e27512c7c473fb80712e0dffcd5d1f..5f5b31c68a37b551bc46202806940dd8482541da 100644 (file)
@@ -47,7 +47,7 @@ RESERVED_WORDS = set([
     'using', 'verbose', 'when', 'where'])
 
 LEGAL_CHARACTERS = re.compile(r'^[A-Z0-9_$]+$', re.I)
-ILLEGAL_INITIAL_CHARACTERS = re.compile(r'[0-9$]')
+ILLEGAL_INITIAL_CHARACTERS = set(xrange(0, 10)).union(['$'])
 
 BIND_PARAMS = re.compile(r'(?<![:\w\$\x5c]):([\w\$]+)(?![:\w\$])', re.UNICODE)
 BIND_PARAMS_ESC = re.compile(r'\x5c(:[\w\$]+)(?![:\w\$])', re.UNICODE)
@@ -776,12 +776,17 @@ class SQLCompiler(engine.Compiled):
         self.prefetch = []
         self.returning = []
 
+        # get the keys of explicitly constructed bindparam() objects
+        bind_names = set(b.key for b in visitors.iterate(stmt, {}) if b.__visit_name__ == 'bindparam')
+        if stmt.parameters:
+            bind_names.update(stmt.parameters)
+
         # no parameters in the statement, no parameters in the
         # compiled params - return binds for all columns
         if self.column_keys is None and stmt.parameters is None:
             return [
                         (c, self._create_crud_bind_param(c, None, required=True)) 
-                        for c in stmt.table.columns
+                        for c in stmt.table.columns if c.key not in bind_names
                     ]
 
         required = object()
@@ -792,7 +797,7 @@ class SQLCompiler(engine.Compiled):
             parameters = {}
         else:
             parameters = dict((sql._column_as_key(key), required)
-                              for key in self.column_keys)
+                              for key in self.column_keys if key not in bind_names)
 
         if stmt.parameters is not None:
             for k, v in stmt.parameters.iteritems():
@@ -1312,7 +1317,7 @@ class IdentifierPreparer(object):
         """Return True if the given identifier requires quoting."""
         lc_value = value.lower()
         return (lc_value in self.reserved_words
-                or self.illegal_initial_characters.match(value[0])
+                or value[0] in self.illegal_initial_characters
                 or not self.legal_characters.match(unicode(value))
                 or (lc_value != value))
 
index 3dc09c9dfc1bb33e5c6c8ba265d5e41f79978a2d..1db2559bc6ca2df89b533bd533987d282786f195 100644 (file)
@@ -1574,7 +1574,35 @@ class CRUDTest(TestBase, AssertsCompiledSQL):
         s = select([table2.c.othername], table2.c.otherid == table1.c.myid)
         u = table1.delete(table1.c.name==s)
         self.assert_compile(u, "DELETE FROM mytable WHERE mytable.name = (SELECT myothertable.othername FROM myothertable WHERE myothertable.otherid = mytable.myid)")
+    
+    def test_binds_that_match_columns(self):
+        """test bind params named after column names replace the normal SET/VALUES generation."""
+        
+        t = table('foo', column('x'), column('y'))
 
+        u = t.update().where(t.c.x==bindparam('x'))
+        
+        self.assert_compile(u, "UPDATE foo SET y=:y WHERE foo.x = :x")
+        self.assert_compile(u, "UPDATE foo SET  WHERE foo.x = :x", params={})
+        self.assert_compile(u.values(x=7), "UPDATE foo SET x=:x WHERE foo.x = :x")
+        self.assert_compile(u.values(y=7), "UPDATE foo SET y=:y WHERE foo.x = :x")
+        self.assert_compile(u.values(x=7), "UPDATE foo SET x=:x, y=:y WHERE foo.x = :x", params={'x':1, 'y':2})
+        self.assert_compile(u, "UPDATE foo SET y=:y WHERE foo.x = :x", params={'x':1, 'y':2})
+        
+        self.assert_compile(u.values(x=3 + bindparam('x')), "UPDATE foo SET x=(:param_1 + :x) WHERE foo.x = :x")
+        self.assert_compile(u.values(x=3 + bindparam('x')), "UPDATE foo SET x=(:param_1 + :x) WHERE foo.x = :x", params={'x':1})
+        self.assert_compile(u.values(x=3 + bindparam('x')), "UPDATE foo SET x=(:param_1 + :x), y=:y WHERE foo.x = :x", params={'x':1, 'y':2})
+
+        i = t.insert().values(x=3 + bindparam('x'))
+        self.assert_compile(i, "INSERT INTO foo (x) VALUES ((:param_1 + :x))")
+        self.assert_compile(i, "INSERT INTO foo (x, y) VALUES ((:param_1 + :x), :y)", params={'x':1, 'y':2})
+
+        i = t.insert().values(x=3 + bindparam('x2'))
+        self.assert_compile(i, "INSERT INTO foo (x) VALUES ((:param_1 + :x2))")
+        self.assert_compile(i, "INSERT INTO foo (x) VALUES ((:param_1 + :x2))", params={})
+        self.assert_compile(i, "INSERT INTO foo (x, y) VALUES ((:param_1 + :x2), :y)", params={'x':1, 'y':2})
+        self.assert_compile(i, "INSERT INTO foo (x, y) VALUES ((:param_1 + :x2), :y)", params={'x2':1, 'y':2})
+        
 class InlineDefaultTest(TestBase, AssertsCompiledSQL):
     def test_insert(self):
         m = MetaData()