]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- A change to the solution for [ticket:1579] - an end-user
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 16 Feb 2010 19:47:54 +0000 (19:47 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 16 Feb 2010 19:47:54 +0000 (19:47 +0000)
defined bind parameter name that directly conflicts with
a column-named bind generated directly from the SET or
VALUES clause of an update/insert generates a compile error.
This reduces call counts and eliminates some cases where
undesirable name conflicts could still occur.

CHANGES
lib/sqlalchemy/sql/compiler.py
test/aaa_profiling/test_compiler.py
test/sql/test_select.py

diff --git a/CHANGES b/CHANGES
index 2fa8cb151ef085c77e34ca31c88f6fad78fc3301..4141c3d8e1926bdb0d2a3027054f3de47393c8d6 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -115,6 +115,13 @@ CHANGES
     decorator - these may also become "public" for the benefit
     of the compiler extension at some point.
 
+  - A change to the solution for [ticket:1579] - an end-user
+    defined bind parameter name that directly conflicts with 
+    a column-named bind generated directly from the SET or
+    VALUES clause of an update/insert generates a compile error.
+    This reduces call counts and eliminates some cases where
+    undesirable name conflicts could still occur.
+    
 - engines
   - Added an optional C extension to speed up the sql layer by
     reimplementing RowProxy and the most common result processors.
index fbe93f6b49dc9215489f1180888a1b8fff268a67..187b4d26f02561998f487bc527bf588643325af7 100644 (file)
@@ -482,10 +482,19 @@ class SQLCompiler(engine.Compiled):
         name = self._truncate_bindparam(bindparam)
         if name in self.binds:
             existing = self.binds[name]
-            if existing is not bindparam and (existing.unique or bindparam.unique):
-                raise exc.CompileError(
-                        "Bind parameter '%s' conflicts with unique bind parameter of the same name" % bindparam.key
-                    )
+            if existing is not bindparam:
+                if existing.unique or bindparam.unique:
+                    raise exc.CompileError(
+                            "Bind parameter '%s' conflicts with "
+                            "unique bind parameter of the same name" % bindparam.key
+                        )
+                elif getattr(existing, '_is_crud', False):
+                    raise exc.CompileError(
+                            "Bind parameter name '%s' is reserved "
+                            "for the VALUES or SET clause of this insert/update statement." 
+                            % bindparam.key
+                        )
+                    
         self.binds[bindparam.key] = self.binds[name] = bindparam
         return self.bindparam_string(name)
 
@@ -696,8 +705,9 @@ class SQLCompiler(engine.Compiled):
         if not colparams and \
                 not self.dialect.supports_default_values and \
                 not self.dialect.supports_empty_insert:
-            raise exc.CompileError(
-                "The version of %s you are using does not support empty inserts." % self.dialect.name)
+            raise exc.CompileError("The version of %s you are using does "
+                                    "not support empty inserts." % 
+                                    self.dialect.name)
 
         preparer = self.preparer
         supports_default_values = self.dialect.supports_default_values
@@ -763,6 +773,14 @@ class SQLCompiler(engine.Compiled):
 
     def _create_crud_bind_param(self, col, value, required=False):
         bindparam = sql.bindparam(col.key, value, type_=col.type, required=required)
+        bindparam._is_crud = True
+        if col.key in self.binds:
+            raise exc.CompileError(
+                    "Bind parameter name '%s' is reserved "
+                    "for the VALUES or SET clause of this insert/update statement." 
+                    % col.key
+                )
+            
         self.binds[col.key] = bindparam
         return self.bindparam_string(self._truncate_bindparam(bindparam))
         
@@ -781,20 +799,12 @@ class SQLCompiler(engine.Compiled):
         self.prefetch = []
         self.returning = []
 
-        # get the keys of explicitly constructed bindparam() objects
-        # TODO: ouch
-        bind_names = set(b.key for b in visitors.iterate(stmt, {}) 
-                            if b.__visit_name__ == 'bindparam')
-        
-        if stmt.parameters:
-            bind_names.update(stmt.parameters)
-
         # no parameters in the statement, no parameters in the
         # compiled params - return binds for all columns
         if self.column_keys is None and stmt.parameters is None:
             return [
                         (c, self._create_crud_bind_param(c, None, required=True)) 
-                        for c in stmt.table.columns if c.key not in bind_names
+                        for c in stmt.table.columns
                     ]
 
         required = object()
@@ -805,7 +815,8 @@ class SQLCompiler(engine.Compiled):
             parameters = {}
         else:
             parameters = dict((sql._column_as_key(key), required)
-                              for key in self.column_keys if key not in bind_names)
+                              for key in self.column_keys 
+                              if not stmt.parameters or key not in stmt.parameters)
 
         if stmt.parameters is not None:
             for k, v in stmt.parameters.iteritems():
index 0232ae15db126d0db90e3cdc0a2e9b5082dca9a6..ffd69f97d31f6d914e2a63a227df98b1d5ef8e56 100644 (file)
@@ -23,6 +23,10 @@ class CompileTest(TestBase, AssertsExecutionResults):
     def test_update(self):
         t1.update().compile()
 
+    @profiling.function_call_count(128, {'2.4': 90})
+    def test_update_whereclause(self):
+        t1.update().where(t1.c.c2==12).compile()
+
     @profiling.function_call_count(195, versions={'2.4':118, '3.0':208, '3.1':208})
     def test_select(self):
         s = select([t1], t1.c.c2==t2.c.c1)
index 657509d6591697130511efccf473ddbc61dd9e5b..33bbe5ff43b94f9560be784425240a9c54856f7e 100644 (file)
@@ -1721,32 +1721,57 @@ class CRUDTest(TestBase, AssertsCompiledSQL):
         self.assert_compile(u, "DELETE FROM mytable WHERE mytable.name = (SELECT myothertable.othername FROM myothertable WHERE myothertable.otherid = mytable.myid)")
     
     def test_binds_that_match_columns(self):
-        """test bind params named after column names replace the normal SET/VALUES generation."""
+        """test bind params named after column names 
+        replace the normal SET/VALUES generation."""
         
         t = table('foo', column('x'), column('y'))
 
         u = t.update().where(t.c.x==bindparam('x'))
+    
+        assert_raises(exc.CompileError, u.compile)
         
-        self.assert_compile(u, "UPDATE foo SET y=:y WHERE foo.x = :x")
         self.assert_compile(u, "UPDATE foo SET  WHERE foo.x = :x", params={})
-        self.assert_compile(u.values(x=7), "UPDATE foo SET x=:x WHERE foo.x = :x")
+
+        assert_raises(exc.CompileError, u.values(x=7).compile)
+        
         self.assert_compile(u.values(y=7), "UPDATE foo SET y=:y WHERE foo.x = :x")
-        self.assert_compile(u.values(x=7), "UPDATE foo SET x=:x, y=:y WHERE foo.x = :x", params={'x':1, 'y':2})
-        self.assert_compile(u, "UPDATE foo SET y=:y WHERE foo.x = :x", params={'x':1, 'y':2})
         
-        self.assert_compile(u.values(x=3 + bindparam('x')), "UPDATE foo SET x=(:param_1 + :x) WHERE foo.x = :x")
-        self.assert_compile(u.values(x=3 + bindparam('x')), "UPDATE foo SET x=(:param_1 + :x) WHERE foo.x = :x", params={'x':1})
-        self.assert_compile(u.values(x=3 + bindparam('x')), "UPDATE foo SET x=(:param_1 + :x), y=:y WHERE foo.x = :x", params={'x':1, 'y':2})
+        assert_raises(exc.CompileError, u.values(x=7).compile, column_keys=['x', 'y'])
+        assert_raises(exc.CompileError, u.compile, column_keys=['x', 'y'])
+        
+        self.assert_compile(u.values(x=3 + bindparam('x')), 
+                            "UPDATE foo SET x=(:param_1 + :x) WHERE foo.x = :x")
+
+        self.assert_compile(u.values(x=3 + bindparam('x')), 
+                            "UPDATE foo SET x=(:param_1 + :x) WHERE foo.x = :x",
+                            params={'x':1})
+
+        self.assert_compile(u.values(x=3 + bindparam('x')), 
+                            "UPDATE foo SET x=(:param_1 + :x), y=:y WHERE foo.x = :x",
+                            params={'x':1, 'y':2})
 
         i = t.insert().values(x=3 + bindparam('x'))
         self.assert_compile(i, "INSERT INTO foo (x) VALUES ((:param_1 + :x))")
-        self.assert_compile(i, "INSERT INTO foo (x, y) VALUES ((:param_1 + :x), :y)", params={'x':1, 'y':2})
+        self.assert_compile(i, 
+                            "INSERT INTO foo (x, y) VALUES ((:param_1 + :x), :y)",
+                            params={'x':1, 'y':2})
+
+        i = t.insert().values(x=bindparam('y'))
+        self.assert_compile(i, "INSERT INTO foo (x) VALUES (:y)")
 
+        i = t.insert().values(x=bindparam('y'), y=5)
+        assert_raises(exc.CompileError, i.compile)
+
+        i = t.insert().values(x=3 + bindparam('y'), y=5)
+        assert_raises(exc.CompileError, i.compile)
+        
         i = t.insert().values(x=3 + bindparam('x2'))
         self.assert_compile(i, "INSERT INTO foo (x) VALUES ((:param_1 + :x2))")
         self.assert_compile(i, "INSERT INTO foo (x) VALUES ((:param_1 + :x2))", params={})
-        self.assert_compile(i, "INSERT INTO foo (x, y) VALUES ((:param_1 + :x2), :y)", params={'x':1, 'y':2})
-        self.assert_compile(i, "INSERT INTO foo (x, y) VALUES ((:param_1 + :x2), :y)", params={'x2':1, 'y':2})
+        self.assert_compile(i, "INSERT INTO foo (x, y) VALUES ((:param_1 + :x2), :y)",
+                                    params={'x':1, 'y':2})
+        self.assert_compile(i, "INSERT INTO foo (x, y) VALUES ((:param_1 + :x2), :y)",
+                                    params={'x2':1, 'y':2})
         
 class InlineDefaultTest(TestBase, AssertsCompiledSQL):
     def test_insert(self):