From: Mike Bayer Date: Tue, 16 Feb 2010 19:47:54 +0000 (+0000) Subject: - A change to the solution for [ticket:1579] - an end-user X-Git-Tag: rel_0_6beta2~174 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=ebe7f7b15e03a30ae14263f16fed8a18c35eebd9;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - A change to the solution for [ticket:1579] - an end-user defined bind parameter name that directly conflicts with a column-named bind generated directly from the SET or VALUES clause of an update/insert generates a compile error. This reduces call counts and eliminates some cases where undesirable name conflicts could still occur. --- diff --git a/CHANGES b/CHANGES index 2fa8cb151e..4141c3d8e1 100644 --- a/CHANGES +++ b/CHANGES @@ -115,6 +115,13 @@ CHANGES decorator - these may also become "public" for the benefit of the compiler extension at some point. + - A change to the solution for [ticket:1579] - an end-user + defined bind parameter name that directly conflicts with + a column-named bind generated directly from the SET or + VALUES clause of an update/insert generates a compile error. + This reduces call counts and eliminates some cases where + undesirable name conflicts could still occur. + - engines - Added an optional C extension to speed up the sql layer by reimplementing RowProxy and the most common result processors. diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index fbe93f6b49..187b4d26f0 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -482,10 +482,19 @@ class SQLCompiler(engine.Compiled): name = self._truncate_bindparam(bindparam) if name in self.binds: existing = self.binds[name] - if existing is not bindparam and (existing.unique or bindparam.unique): - raise exc.CompileError( - "Bind parameter '%s' conflicts with unique bind parameter of the same name" % bindparam.key - ) + if existing is not bindparam: + if existing.unique or bindparam.unique: + raise exc.CompileError( + "Bind parameter '%s' conflicts with " + "unique bind parameter of the same name" % bindparam.key + ) + elif getattr(existing, '_is_crud', False): + raise exc.CompileError( + "Bind parameter name '%s' is reserved " + "for the VALUES or SET clause of this insert/update statement." + % bindparam.key + ) + self.binds[bindparam.key] = self.binds[name] = bindparam return self.bindparam_string(name) @@ -696,8 +705,9 @@ class SQLCompiler(engine.Compiled): if not colparams and \ not self.dialect.supports_default_values and \ not self.dialect.supports_empty_insert: - raise exc.CompileError( - "The version of %s you are using does not support empty inserts." % self.dialect.name) + raise exc.CompileError("The version of %s you are using does " + "not support empty inserts." % + self.dialect.name) preparer = self.preparer supports_default_values = self.dialect.supports_default_values @@ -763,6 +773,14 @@ class SQLCompiler(engine.Compiled): def _create_crud_bind_param(self, col, value, required=False): bindparam = sql.bindparam(col.key, value, type_=col.type, required=required) + bindparam._is_crud = True + if col.key in self.binds: + raise exc.CompileError( + "Bind parameter name '%s' is reserved " + "for the VALUES or SET clause of this insert/update statement." + % col.key + ) + self.binds[col.key] = bindparam return self.bindparam_string(self._truncate_bindparam(bindparam)) @@ -781,20 +799,12 @@ class SQLCompiler(engine.Compiled): self.prefetch = [] self.returning = [] - # get the keys of explicitly constructed bindparam() objects - # TODO: ouch - bind_names = set(b.key for b in visitors.iterate(stmt, {}) - if b.__visit_name__ == 'bindparam') - - if stmt.parameters: - bind_names.update(stmt.parameters) - # no parameters in the statement, no parameters in the # compiled params - return binds for all columns if self.column_keys is None and stmt.parameters is None: return [ (c, self._create_crud_bind_param(c, None, required=True)) - for c in stmt.table.columns if c.key not in bind_names + for c in stmt.table.columns ] required = object() @@ -805,7 +815,8 @@ class SQLCompiler(engine.Compiled): parameters = {} else: parameters = dict((sql._column_as_key(key), required) - for key in self.column_keys if key not in bind_names) + for key in self.column_keys + if not stmt.parameters or key not in stmt.parameters) if stmt.parameters is not None: for k, v in stmt.parameters.iteritems(): diff --git a/test/aaa_profiling/test_compiler.py b/test/aaa_profiling/test_compiler.py index 0232ae15db..ffd69f97d3 100644 --- a/test/aaa_profiling/test_compiler.py +++ b/test/aaa_profiling/test_compiler.py @@ -23,6 +23,10 @@ class CompileTest(TestBase, AssertsExecutionResults): def test_update(self): t1.update().compile() + @profiling.function_call_count(128, {'2.4': 90}) + def test_update_whereclause(self): + t1.update().where(t1.c.c2==12).compile() + @profiling.function_call_count(195, versions={'2.4':118, '3.0':208, '3.1':208}) def test_select(self): s = select([t1], t1.c.c2==t2.c.c1) diff --git a/test/sql/test_select.py b/test/sql/test_select.py index 657509d659..33bbe5ff43 100644 --- a/test/sql/test_select.py +++ b/test/sql/test_select.py @@ -1721,32 +1721,57 @@ class CRUDTest(TestBase, AssertsCompiledSQL): self.assert_compile(u, "DELETE FROM mytable WHERE mytable.name = (SELECT myothertable.othername FROM myothertable WHERE myothertable.otherid = mytable.myid)") def test_binds_that_match_columns(self): - """test bind params named after column names replace the normal SET/VALUES generation.""" + """test bind params named after column names + replace the normal SET/VALUES generation.""" t = table('foo', column('x'), column('y')) u = t.update().where(t.c.x==bindparam('x')) + + assert_raises(exc.CompileError, u.compile) - self.assert_compile(u, "UPDATE foo SET y=:y WHERE foo.x = :x") self.assert_compile(u, "UPDATE foo SET WHERE foo.x = :x", params={}) - self.assert_compile(u.values(x=7), "UPDATE foo SET x=:x WHERE foo.x = :x") + + assert_raises(exc.CompileError, u.values(x=7).compile) + self.assert_compile(u.values(y=7), "UPDATE foo SET y=:y WHERE foo.x = :x") - self.assert_compile(u.values(x=7), "UPDATE foo SET x=:x, y=:y WHERE foo.x = :x", params={'x':1, 'y':2}) - self.assert_compile(u, "UPDATE foo SET y=:y WHERE foo.x = :x", params={'x':1, 'y':2}) - self.assert_compile(u.values(x=3 + bindparam('x')), "UPDATE foo SET x=(:param_1 + :x) WHERE foo.x = :x") - self.assert_compile(u.values(x=3 + bindparam('x')), "UPDATE foo SET x=(:param_1 + :x) WHERE foo.x = :x", params={'x':1}) - self.assert_compile(u.values(x=3 + bindparam('x')), "UPDATE foo SET x=(:param_1 + :x), y=:y WHERE foo.x = :x", params={'x':1, 'y':2}) + assert_raises(exc.CompileError, u.values(x=7).compile, column_keys=['x', 'y']) + assert_raises(exc.CompileError, u.compile, column_keys=['x', 'y']) + + self.assert_compile(u.values(x=3 + bindparam('x')), + "UPDATE foo SET x=(:param_1 + :x) WHERE foo.x = :x") + + self.assert_compile(u.values(x=3 + bindparam('x')), + "UPDATE foo SET x=(:param_1 + :x) WHERE foo.x = :x", + params={'x':1}) + + self.assert_compile(u.values(x=3 + bindparam('x')), + "UPDATE foo SET x=(:param_1 + :x), y=:y WHERE foo.x = :x", + params={'x':1, 'y':2}) i = t.insert().values(x=3 + bindparam('x')) self.assert_compile(i, "INSERT INTO foo (x) VALUES ((:param_1 + :x))") - self.assert_compile(i, "INSERT INTO foo (x, y) VALUES ((:param_1 + :x), :y)", params={'x':1, 'y':2}) + self.assert_compile(i, + "INSERT INTO foo (x, y) VALUES ((:param_1 + :x), :y)", + params={'x':1, 'y':2}) + + i = t.insert().values(x=bindparam('y')) + self.assert_compile(i, "INSERT INTO foo (x) VALUES (:y)") + i = t.insert().values(x=bindparam('y'), y=5) + assert_raises(exc.CompileError, i.compile) + + i = t.insert().values(x=3 + bindparam('y'), y=5) + assert_raises(exc.CompileError, i.compile) + i = t.insert().values(x=3 + bindparam('x2')) self.assert_compile(i, "INSERT INTO foo (x) VALUES ((:param_1 + :x2))") self.assert_compile(i, "INSERT INTO foo (x) VALUES ((:param_1 + :x2))", params={}) - self.assert_compile(i, "INSERT INTO foo (x, y) VALUES ((:param_1 + :x2), :y)", params={'x':1, 'y':2}) - self.assert_compile(i, "INSERT INTO foo (x, y) VALUES ((:param_1 + :x2), :y)", params={'x2':1, 'y':2}) + self.assert_compile(i, "INSERT INTO foo (x, y) VALUES ((:param_1 + :x2), :y)", + params={'x':1, 'y':2}) + self.assert_compile(i, "INSERT INTO foo (x, y) VALUES ((:param_1 + :x2), :y)", + params={'x2':1, 'y':2}) class InlineDefaultTest(TestBase, AssertsCompiledSQL): def test_insert(self):