]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- added generative where(<criterion>) method to delete()
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 8 Feb 2008 22:57:45 +0000 (22:57 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 8 Feb 2008 22:57:45 +0000 (22:57 +0000)
and update() constructs which return a new object with
criterion joined to existing criterion via AND, just
like select().where().
- compile assertions use assertEquals()

CHANGES
lib/sqlalchemy/sql/expression.py
test/sql/select.py
test/testlib/testing.py

diff --git a/CHANGES b/CHANGES
index 50df3a03031d841f77bc1c3c52b9f7ffb39ff2d0..605569ead152e75a04384f71b8b743b51172c58b 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -13,6 +13,11 @@ CHANGES
     - Added a callable-based DDL events interface, adds hooks
       before and after Tables and MetaData create and drop.
 
+    - added generative where(<criterion>) method to delete() 
+      and update() constructs which return a new object with
+      criterion joined to existing criterion via AND, just
+      like select().where().
+      
     - Added "ilike()" operator to column operations.  Compiles to
       ILIKE on postgres, lower(x) LIKE lower(y) on all
       others. [ticket:727]
index 5c8008f3d02a1d13feb472d0705fb1be8b60bc14..79eb1759d3d01baa4f7daf31eefa71aa56ae5ff0 100644 (file)
@@ -2178,6 +2178,9 @@ class _Exists(_UnaryExpression):
         return e
 
     def where(self, clause):
+        """return a new exists() construct with the given expression added to its WHERE clause, joined
+        to the existing clause via AND, if any."""
+
         e = self._clone()
         e.element = self.element.where(clause).self_group()
         return e
@@ -3493,7 +3496,10 @@ class Insert(_UpdateBase):
 class Update(_UpdateBase):
     def __init__(self, table, whereclause, values=None, inline=False, **kwargs):
         self.table = table
-        self._whereclause = whereclause
+        if whereclause:
+            self._whereclause = _literal_as_text(whereclause)
+        else:
+            self._whereclause = None
         self.inline = inline
         self.parameters = self._process_colparams(values)
 
@@ -3509,6 +3515,17 @@ class Update(_UpdateBase):
         self._whereclause = clone(self._whereclause)
         self.parameters = self.parameters.copy()
 
+    def where(self, whereclause):
+        """return a new update() construct with the given expression added to its WHERE clause, joined
+        to the existing clause via AND, if any."""
+        
+        s = self._clone()
+        if s._whereclause is not None:
+            s._whereclause = and_(s._whereclause, _literal_as_text(whereclause))
+        else:
+            s._whereclause = _literal_as_text(whereclause)
+        return s
+
     def values(self, v):
         if len(v) == 0:
             return self
@@ -3523,7 +3540,10 @@ class Update(_UpdateBase):
 class Delete(_UpdateBase):
     def __init__(self, table, whereclause):
         self.table = table
-        self._whereclause = whereclause
+        if whereclause:
+            self._whereclause = _literal_as_text(whereclause)
+        else:
+            self._whereclause = None
 
     def get_children(self, **kwargs):
         if self._whereclause is not None:
@@ -3531,6 +3551,17 @@ class Delete(_UpdateBase):
         else:
             return ()
 
+    def where(self, whereclause):
+        """return a new delete() construct with the given expression added to its WHERE clause, joined
+        to the existing clause via AND, if any."""
+        
+        s = self._clone()
+        if s._whereclause is not None:
+            s._whereclause = and_(s._whereclause, _literal_as_text(whereclause))
+        else:
+            s._whereclause = _literal_as_text(whereclause)
+        return s
+        
     def _copy_internals(self, clone=_clone):
         self._whereclause = clone(self._whereclause)
 
index 69d39ede65f63fc7dda3acdfc46f03b3393f3bea..54b1e87a2c57352041eb7c8f6d17a0fb3429ce86 100644 (file)
@@ -1240,6 +1240,7 @@ class CRUDTest(SQLCompileTest):
 
     def test_update(self):
         self.assert_compile(update(table1, table1.c.myid == 7), "UPDATE mytable SET name=:name WHERE mytable.myid = :mytable_myid_1", params = {table1.c.name:'fred'})
+        self.assert_compile(table1.update().where(table1.c.myid==7).values({table1.c.myid:5}), "UPDATE mytable SET myid=:myid WHERE mytable.myid = :mytable_myid_1", checkparams={'myid':5, 'mytable_myid_1':7})
         self.assert_compile(update(table1, table1.c.myid == 7), "UPDATE mytable SET name=:name WHERE mytable.myid = :mytable_myid_1", params = {'name':'fred'})
         self.assert_compile(update(table1, values = {table1.c.name : table1.c.myid}), "UPDATE mytable SET name=mytable.myid")
         self.assert_compile(update(table1, whereclause = table1.c.name == bindparam('crit'), values = {table1.c.name : 'hi'}), "UPDATE mytable SET name=:name WHERE mytable.name = :crit", params = {'crit' : 'notthere'}, checkparams={'crit':'notthere', 'name':'hi'})
@@ -1288,7 +1289,9 @@ class CRUDTest(SQLCompileTest):
 
     def test_delete(self):
         self.assert_compile(delete(table1, table1.c.myid == 7), "DELETE FROM mytable WHERE mytable.myid = :mytable_myid_1")
-
+        self.assert_compile(table1.delete().where(table1.c.myid == 7), "DELETE FROM mytable WHERE mytable.myid = :mytable_myid_1")
+        self.assert_compile(table1.delete().where(table1.c.myid == 7).where(table1.c.name=='somename'), "DELETE FROM mytable WHERE mytable.myid = :mytable_myid_1 AND mytable.name = :mytable_name_1")
+        
     def test_correlated_delete(self):
         # test a non-correlated WHERE clause
         s = select([table2.c.othername], table2.c.otherid == 7)
index 32e5e9b4a188e2975b34c8fc4b0edecef81975b8..3b7b2992e28341eb8f4d646a9dfb63d09084e39c 100644 (file)
@@ -480,10 +480,10 @@ class SQLCompileTest(PersistTest):
 
         cc = re.sub(r'\n', '', str(c))
 
-        self.assert_(cc == result, "\n'" + cc + "'\n does not match \n'" + result + "'")
+        self.assertEquals(cc, result)
 
         if checkparams is not None:
-            self.assert_(c.construct_params(params) == checkparams, "params dont match" + repr(c.params))
+            self.assertEquals(c.construct_params(params), checkparams)
 
 class AssertMixin(PersistTest):
     """given a list-based structure of keys/properties which represent information within an object structure, and