From 4fffc21c87cbdfc538fe2924f82bf1591823856d Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Wed, 18 Apr 2007 22:54:40 +0000 Subject: [PATCH] - the "where" criterion of an update() and delete() now correlates embedded select() statements against the table being updated or deleted. this works the same as nested select() statement correlation, and can be disabled via the correlate=False flag on the embedded select(). --- CHANGES | 5 +++++ lib/sqlalchemy/sql.py | 26 +++++++++++++++++++++----- test/sql/select.py | 7 ++++++- 3 files changed, 32 insertions(+), 6 deletions(-) diff --git a/CHANGES b/CHANGES index c560f4fe71..3c600c1943 100644 --- a/CHANGES +++ b/CHANGES @@ -32,6 +32,11 @@ of unicode situations that occur in db's such as MS-SQL to be better handled and allows subclassing of the Unicode datatype. [ticket:522] + - the "where" criterion of an update() and delete() now correlates + embedded select() statements against the table being updated or + deleted. this works the same as nested select() statement + correlation, and can be disabled via the correlate=False flag on + the embedded select(). - column labels are now generated in the compilation phase, which means their lengths are dialect-dependent. So on oracle a label that gets truncated to 30 chars will go out to 63 characters diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py index 94b618491c..a8663ed4c3 100644 --- a/lib/sqlalchemy/sql.py +++ b/lib/sqlalchemy/sql.py @@ -2301,6 +2301,7 @@ class Select(_SelectBaseMixin, FromClause): use_labels=False, distinct=False, for_update=False, engine=None, limit=None, offset=None, scalar=False, correlate=True): + # TODO: docstring ! _SelectBaseMixin.__init__(self) self.__froms = util.OrderedSet() self.__hide_froms = util.Set([self]) @@ -2319,7 +2320,7 @@ class Select(_SelectBaseMixin, FromClause): self.is_scalar = scalar # indicates if this select statement, as a subquery, should automatically correlate - # its FROM clause to that of an enclosing select statement. + # its FROM clause to that of an enclosing select, update, or delete statement. # note that the "correlate" method can be used to explicitly add a value to be correlated. self.should_correlate = correlate @@ -2560,6 +2561,20 @@ class _UpdateBase(ClauseElement): def supports_execution(self): return True + class _SelectCorrelator(NoColumnVisitor): + def __init__(self, table): + NoColumnVisitor.__init__(self) + self.table = table + + def visit_select(self, select): + if select.should_correlate: + select.correlate(self.table) + + def _process_whereclause(self, whereclause): + if whereclause is not None: + _UpdateBase._SelectCorrelator(self.table).traverse(whereclause) + return whereclause + def _process_colparams(self, parameters): """Receive the *values* of an ``INSERT`` or ``UPDATE`` statement and construct appropriate bind parameters. @@ -2576,10 +2591,11 @@ class _UpdateBase(ClauseElement): i +=1 parameters = pp + correlator = _UpdateBase._SelectCorrelator(self.table) for key in parameters.keys(): value = parameters[key] - if isinstance(value, Select): - value.correlate(self.table) + if isinstance(value, ClauseElement): + correlator.traverse(value) elif _is_literal(value): if _is_literal(key): col = self.table.c[key] @@ -2611,7 +2627,7 @@ class _Insert(_UpdateBase): class _Update(_UpdateBase): def __init__(self, table, whereclause, values=None): self.table = table - self.whereclause = whereclause + self.whereclause = self._process_whereclause(whereclause) self.parameters = self._process_colparams(values) def get_children(self, **kwargs): @@ -2625,7 +2641,7 @@ class _Update(_UpdateBase): class _Delete(_UpdateBase): def __init__(self, table, whereclause): self.table = table - self.whereclause = whereclause + self.whereclause = self._process_whereclause(whereclause) def get_children(self, **kwargs): if self.whereclause is not None: diff --git a/test/sql/select.py b/test/sql/select.py index 91b293cbe6..1d0a63e2f6 100644 --- a/test/sql/select.py +++ b/test/sql/select.py @@ -828,10 +828,15 @@ class CRUDTest(SQLTest): u = update(table1, table1.c.name == 'jack', values = {table1.c.name : s}) self.runtest(u, "UPDATE mytable SET name=(SELECT myothertable.otherid, myothertable.othername FROM myothertable WHERE myothertable.otherid = mytable.myid) WHERE mytable.name = :mytable_name") - # test a correlated WHERE clause + # test a non-correlated WHERE clause s = select([table2.c.othername], table2.c.otherid == 7) u = update(table1, table1.c.name==s) self.runtest(u, "UPDATE mytable SET myid=:myid, name=:name, description=:description WHERE mytable.name = (SELECT myothertable.othername FROM myothertable WHERE myothertable.otherid = :myothertable_otherid)") + + # test one that is actually correlated... + s = select([table2.c.othername], table2.c.otherid == table1.c.myid) + u = table1.update(table1.c.name==s) + self.runtest(u, "UPDATE mytable SET myid=:myid, name=:name, description=:description WHERE mytable.name = (SELECT myothertable.othername FROM myothertable WHERE myothertable.otherid = mytable.myid)") def testdelete(self): self.runtest(delete(table1, table1.c.myid == 7), "DELETE FROM mytable WHERE mytable.myid = :mytable_myid") -- 2.47.2