]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- the "where" criterion of an update() and delete() now correlates
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 18 Apr 2007 22:54:40 +0000 (22:54 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 18 Apr 2007 22:54:40 +0000 (22:54 +0000)
embedded select() statements against the table being updated or
deleted.  this works the same as nested select() statement
correlation, and can be disabled via the correlate=False flag on
the embedded select().

CHANGES
lib/sqlalchemy/sql.py
test/sql/select.py

diff --git a/CHANGES b/CHANGES
index c560f4fe71a713313c42c4b138dc6dceeaddfc48..3c600c1943f8165d4e86095a6d045fd987d1d4c5 100644 (file)
--- a/CHANGES
+++ b/CHANGES
       of unicode situations that occur in db's such as MS-SQL to be
       better handled and allows subclassing of the Unicode datatype.
       [ticket:522]
+    - the "where" criterion of an update() and delete() now correlates
+      embedded select() statements against the table being updated or
+      deleted.  this works the same as nested select() statement
+      correlation, and can be disabled via the correlate=False flag on 
+      the embedded select().
     - column labels are now generated in the compilation phase, which
       means their lengths are dialect-dependent.  So on oracle a label
       that gets truncated to 30 chars will go out to 63 characters
index 94b618491c85255b75128776a5e13ac028c93c7f..a8663ed4c3a473cc827b326b18d163807c4e1c72 100644 (file)
@@ -2301,6 +2301,7 @@ class Select(_SelectBaseMixin, FromClause):
                  use_labels=False, distinct=False, for_update=False,
                  engine=None, limit=None, offset=None, scalar=False,
                  correlate=True):
+        # TODO: docstring ! 
         _SelectBaseMixin.__init__(self)
         self.__froms = util.OrderedSet()
         self.__hide_froms = util.Set([self])
@@ -2319,7 +2320,7 @@ class Select(_SelectBaseMixin, FromClause):
         self.is_scalar = scalar
 
         # indicates if this select statement, as a subquery, should automatically correlate
-        # its FROM clause to that of an enclosing select statement.
+        # its FROM clause to that of an enclosing select, update, or delete statement.
         # note that the "correlate" method can be used to explicitly add a value to be correlated.
         self.should_correlate = correlate
 
@@ -2560,6 +2561,20 @@ class _UpdateBase(ClauseElement):
     def supports_execution(self):
         return True
 
+    class _SelectCorrelator(NoColumnVisitor):
+        def __init__(self, table):
+            NoColumnVisitor.__init__(self)
+            self.table = table
+            
+        def visit_select(self, select):
+            if select.should_correlate:
+                select.correlate(self.table)
+    
+    def _process_whereclause(self, whereclause):
+        if whereclause is not None:
+            _UpdateBase._SelectCorrelator(self.table).traverse(whereclause)
+        return whereclause
+        
     def _process_colparams(self, parameters):
         """Receive the *values* of an ``INSERT`` or ``UPDATE``
         statement and construct appropriate bind parameters.
@@ -2576,10 +2591,11 @@ class _UpdateBase(ClauseElement):
                 i +=1
             parameters = pp
 
+        correlator = _UpdateBase._SelectCorrelator(self.table)
         for key in parameters.keys():
             value = parameters[key]
-            if isinstance(value, Select):
-                value.correlate(self.table)
+            if isinstance(value, ClauseElement):
+                correlator.traverse(value)
             elif _is_literal(value):
                 if _is_literal(key):
                     col = self.table.c[key]
@@ -2611,7 +2627,7 @@ class _Insert(_UpdateBase):
 class _Update(_UpdateBase):
     def __init__(self, table, whereclause, values=None):
         self.table = table
-        self.whereclause = whereclause
+        self.whereclause = self._process_whereclause(whereclause)
         self.parameters = self._process_colparams(values)
 
     def get_children(self, **kwargs):
@@ -2625,7 +2641,7 @@ class _Update(_UpdateBase):
 class _Delete(_UpdateBase):
     def __init__(self, table, whereclause):
         self.table = table
-        self.whereclause = whereclause
+        self.whereclause = self._process_whereclause(whereclause)
 
     def get_children(self, **kwargs):
         if self.whereclause is not None:
index 91b293cbe6737730a96075b287896a7794e6e886..1d0a63e2f64a0e40f96461e4e1443cf9ba09f9b6 100644 (file)
@@ -828,10 +828,15 @@ class CRUDTest(SQLTest):
         u = update(table1, table1.c.name == 'jack', values = {table1.c.name : s})
         self.runtest(u, "UPDATE mytable SET name=(SELECT myothertable.otherid, myothertable.othername FROM myothertable WHERE myothertable.otherid = mytable.myid) WHERE mytable.name = :mytable_name")
 
-        # test a correlated WHERE clause
+        # test a non-correlated WHERE clause
         s = select([table2.c.othername], table2.c.otherid == 7)
         u = update(table1, table1.c.name==s)
         self.runtest(u, "UPDATE mytable SET myid=:myid, name=:name, description=:description WHERE mytable.name = (SELECT myothertable.othername FROM myothertable WHERE myothertable.otherid = :myothertable_otherid)")
+
+        # test one that is actually correlated...
+        s = select([table2.c.othername], table2.c.otherid == table1.c.myid)
+        u = table1.update(table1.c.name==s)
+        self.runtest(u, "UPDATE mytable SET myid=:myid, name=:name, description=:description WHERE mytable.name = (SELECT myothertable.othername FROM myothertable WHERE myothertable.otherid = mytable.myid)")
         
     def testdelete(self):
         self.runtest(delete(table1, table1.c.myid == 7), "DELETE FROM mytable WHERE mytable.myid = :mytable_myid")