]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
add support for returning results from inserts and updates for postgresql 8.2+. ...
authorAnts Aasma <ants.aasma@gmail.com>
Tue, 2 Oct 2007 23:57:54 +0000 (23:57 +0000)
committerAnts Aasma <ants.aasma@gmail.com>
Tue, 2 Oct 2007 23:57:54 +0000 (23:57 +0000)
CHANGES
lib/sqlalchemy/databases/postgres.py
lib/sqlalchemy/sql/expression.py
test/sql/query.py
test/sql/select.py

diff --git a/CHANGES b/CHANGES
index 78f94359390f14c3194dea6b2cb1b992fe772d6c..0bab8b9fc32cf69dce70d94ca97b7920cc685d0c 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -24,6 +24,9 @@ CHANGES
 - The no-arg ResultProxy._row_processor() is now the class attribute
   `_process_row`.
 
+- Added support for returning values from inserts and udpates for
+  PostgreSQL 8.2+. [ticket:797]
+
 0.4.0beta6
 ----------
 
index 305e8e83177c7922c94a3bdd6fda69408b4c054b..03b3fd042f15048a64fd1714b3cd9c483954888c 100644 (file)
@@ -10,13 +10,20 @@ PostgreSQL supports partial indexes. To create them pass a posgres_where
 option to the Index constructor::
 
   Index('my_index', my_table.c.id, postgres_where=tbl.c.value > 10)
+
+PostgreSQL 8.2+ supports returning a result set from inserts and updates.
+To use this pass the column/expression list to the postgres_returning
+parameter when creating the queries::
+    
+  raises = tbl.update(empl.c.sales > 100, values=dict(salary=empl.c.salary * 1.1), 
+    postgres_returning=[empl.c.id, empl.c.salary]).execute().fetchall()
 """
 
 import re, random, warnings, string
 
 from sqlalchemy import sql, schema, exceptions, util
 from sqlalchemy.engine import base, default
-from sqlalchemy.sql import compiler
+from sqlalchemy.sql import compiler, expression
 from sqlalchemy.sql import operators as sql_operators
 from sqlalchemy import types as sqltypes
 
@@ -198,13 +205,27 @@ def descriptor():
     ]}
 
 SELECT_RE = re.compile(
-    r'\s*(?:SELECT|FETCH)',
+    r'\s*(?:SELECT|FETCH|(UPDATE|INSERT))',
+    re.I | re.UNICODE)
+
+RETURNING_RE = re.compile(
+    'RETURNING',
+    re.I | re.UNICODE)
+
+# This finds if the RETURNING is not inside a quoted/commented values. Handles string literals,
+# quoted identifiers, dollar quotes, SQL comments and C style multiline comments. This does not
+# handle correctly nested C style quotes, lets hope no one does the following:
+# UPDATE tbl SET x=y /* foo /* bar */ RETURNING */
+RETURNING_QUOTED_RE = re.compile(
+    '\\s*(?:UPDATE|INSERT)\\s(?:[^\'"$/-]|-(?!-)|/(?!\\*)|"(?:[^"]|"")*"|\'(?:[^\']|\'\')*\'|\\$(?P<dquote>[^$]*)\\$.*?\\$(?P=dquote)\\$|--[^\n]*\n|/\\*([^*]|\\*(?!/))*\\*/)*\\sRETURNING',
     re.I | re.UNICODE)
 
 class PGExecutionContext(default.DefaultExecutionContext):
 
     def is_select(self):
-        return SELECT_RE.match(self.statement)
+        m = SELECT_RE.match(self.statement)
+        return m and (not m.group(1) or (RETURNING_RE.search(self.statement)
+           and RETURNING_QUOTED_RE.match(self.statement)))
         
     def create_cursor(self):
         # executing a default or Sequence standalone creates an execution context without a statement.  
@@ -598,6 +619,29 @@ class PGCompiler(compiler.DefaultCompiler):
         else:
             return super(PGCompiler, self).for_update_clause(select)
 
+    def _append_returning(self, text, stmt):
+        returning_cols = stmt.kwargs.get('postgres_returning', None)
+        if returning_cols:
+            def flatten_columnlist(collist):
+                for c in collist:
+                    if isinstance(c, expression.Selectable):
+                        for co in c.columns:
+                            yield co
+                    else:
+                        yield c
+            columns = [self.process(c) for c in flatten_columnlist(returning_cols)]
+            text += ' RETURNING ' + string.join(columns, ', ')
+        
+        return text
+
+    def visit_update(self, update_stmt):
+        text = super(PGCompiler, self).visit_update(update_stmt)
+        return self._append_returning(text, update_stmt)
+
+    def visit_insert(self, insert_stmt):
+        text = super(PGCompiler, self).visit_insert(insert_stmt)
+        return self._append_returning(text, insert_stmt)
+
 class PGSchemaGenerator(compiler.SchemaGenerator):
     def get_column_specification(self, column, **kwargs):
         colspec = self.preparer.format_column(column)
index 9a9cf65d22118371647e572ec01b4976921b5fb1..6f3ee94abd18faa952d91d9a85f0e4dfde501a38 100644 (file)
@@ -249,7 +249,7 @@ def subquery(alias, *args, **kwargs):
 
     return Select(*args, **kwargs).alias(alias)
 
-def insert(table, values=None, inline=False):
+def insert(table, values=None, inline=False, **kwargs):
     """Return an [sqlalchemy.sql.expression#Insert] clause element.
 
     Similar functionality is available via the ``insert()`` method on
@@ -287,9 +287,9 @@ def insert(table, values=None, inline=False):
     against the ``INSERT`` statement.
     """
 
-    return Insert(table, values, inline=inline)
+    return Insert(table, values, inline=inline, **kwargs)
 
-def update(table, whereclause=None, values=None, inline=False):
+def update(table, whereclause=None, values=None, inline=False, **kwargs):
     """Return an [sqlalchemy.sql.expression#Update] clause element.
 
     Similar functionality is available via the ``update()`` method on
@@ -332,7 +332,7 @@ def update(table, whereclause=None, values=None, inline=False):
     against the ``UPDATE`` statement.
     """
 
-    return Update(table, whereclause=whereclause, values=values, inline=inline)
+    return Update(table, whereclause=whereclause, values=values, inline=inline, **kwargs)
 
 def delete(table, whereclause = None, **kwargs):
     """Return a [sqlalchemy.sql.expression#Delete] clause element.
@@ -2699,11 +2699,11 @@ class TableClause(FromClause):
     def select(self, whereclause = None, **params):
         return select([self], whereclause, **params)
 
-    def insert(self, values=None, inline=False):
-        return insert(self, values=values, inline=inline)
+    def insert(self, values=None, inline=False, **kwargs):
+        return insert(self, values=values, inline=inline, **kwargs)
 
-    def update(self, whereclause=None, values=None, inline=False):
-        return update(self, whereclause=whereclause, values=values, inline=inline)
+    def update(self, whereclause=None, values=None, inline=False, **kwargs):
+        return update(self, whereclause=whereclause, values=values, inline=inline, **kwargs)
 
     def delete(self, whereclause=None):
         return delete(self, whereclause)
@@ -3356,12 +3356,14 @@ class _UpdateBase(ClauseElement):
         return self.table.bind
 
 class Insert(_UpdateBase):
-    def __init__(self, table, values=None, inline=False):
+    def __init__(self, table, values=None, inline=False, **kwargs):
         self.table = table
         self.select = None
         self.inline=inline
         self.parameters = self._process_colparams(values)
 
+        self.kwargs = kwargs
+
     def get_children(self, **kwargs):
         if self.select is not None:
             return self.select,
@@ -3383,12 +3385,14 @@ class Insert(_UpdateBase):
         return u
 
 class Update(_UpdateBase):
-    def __init__(self, table, whereclause, values=None, inline=False):
+    def __init__(self, table, whereclause, values=None, inline=False, **kwargs):
         self.table = table
         self._whereclause = whereclause
         self.inline = inline
         self.parameters = self._process_colparams(values)
 
+        self.kwargs = kwargs
+
     def get_children(self, **kwargs):
         if self._whereclause is not None:
             return self._whereclause,
index a519dd974bc76084ed163a3e2dc07e5c2e577778..b4afbbade012d2600d90f42c70352bfa6be0744b 100644 (file)
@@ -563,7 +563,51 @@ class QueryTest(PersistTest):
         s = users.select(users.c.user_name.in_() == None)
         r = s.execute().fetchall()
         assert len(r) == 1
-        
+
+    @testing.supported('postgres')
+    def test_update_returning(self):
+        meta = MetaData(testbase.db)
+        table = Table('tables', meta, 
+            Column('id', Integer, primary_key=True),
+            Column('persons', Integer),
+            Column('full', Boolean)
+        )
+        table.create()
+        try:
+            table.insert().execute([{'persons': 5, 'full': False}, {'persons': 3, 'full': False}])
+            
+            result = table.update(table.c.persons > 4, dict(full=True), postgres_returning=[table.c.id]).execute()
+            self.assertEqual(result.fetchall(), [(1,)])
+            
+            result2 = select([table.c.id, table.c.full]).order_by(table.c.id).execute()
+            self.assertEqual(result2.fetchall(), [(1,True),(2,False)])
+        finally:
+            table.drop()
+
+    @testing.supported('postgres')
+    def test_insert_returning(self):
+        meta = MetaData(testbase.db)
+        table = Table('tables', meta, 
+            Column('id', Integer, primary_key=True),
+            Column('persons', Integer),
+            Column('full', Boolean)
+        )
+        table.create()
+        try:
+            result = table.insert(postgres_returning=[table.c.id]).execute({'persons': 1, 'full': False})
+            
+            self.assertEqual(result.fetchall(), [(1,)])
+            
+            # Multiple inserts only return the last row
+            result2 = table.insert(postgres_returning=[table]).execute(
+                 [{'persons': 2, 'full': False}, {'persons': 3, 'full': True}])
+             
+            self.assertEqual(result2.fetchall(), [(3,3,True)])
+            
+            result3 = table.insert(postgres_returning=[(table.c.id*2).label('double_id')]).execute({'persons': 4, 'full': False})
+            self.assertEqual([dict(row) for row in result3], [{'double_id':8}])
+        finally:
+            table.drop()
 
 class CompoundTest(PersistTest):
     """test compound statements like UNION, INTERSECT, particularly their ability to nest on
index 1114f17358fa8bae710df8bcdc995c7264992f87..4cdac97d87917cafedf62c36f185bdf06afac09f 100644 (file)
@@ -1184,7 +1184,35 @@ class CRUDTest(SQLCompileTest):
         s = select([table2.c.othername], table2.c.otherid == table1.c.myid)
         u = table1.update(table1.c.name==s)
         self.assert_compile(u, "UPDATE mytable SET myid=:myid, name=:name, description=:description WHERE mytable.name = (SELECT myothertable.othername FROM myothertable WHERE myothertable.otherid = mytable.myid)")
+    
+    @testing.supported('postgres')
+    def testupdatereturning(self):
+        dialect = postgres.dialect()
+        
+        u = update(table1, values=dict(name='foo'), postgres_returning=[table1.c.myid, table1.c.name])
+        self.assert_compile(u, "UPDATE mytable SET name=%(name)s RETURNING mytable.myid, mytable.name", dialect=dialect)
+        
+        u = update(table1, values=dict(name='foo'), postgres_returning=[table1])
+        self.assert_compile(u, "UPDATE mytable SET name=%(name)s "\
+            "RETURNING mytable.myid, mytable.name, mytable.description", dialect=dialect)
+        
+        u = update(table1, values=dict(name='foo'), postgres_returning=[func.length(table1.c.name)])
+        self.assert_compile(u, "UPDATE mytable SET name=%(name)s RETURNING length(mytable.name)", dialect=dialect)
         
+    @testing.supported('postgres')
+    def testinsertreturning(self):
+        dialect = postgres.dialect()
+        
+        i = insert(table1, values=dict(name='foo'), postgres_returning=[table1.c.myid, table1.c.name])
+        self.assert_compile(i, "INSERT INTO mytable (name) VALUES (%(name)s) RETURNING mytable.myid, mytable.name", dialect=dialect)
+        
+        i = insert(table1, values=dict(name='foo'), postgres_returning=[table1])
+        self.assert_compile(i, "INSERT INTO mytable (name) VALUES (%(name)s) "\
+            "RETURNING mytable.myid, mytable.name, mytable.description", dialect=dialect)
+        
+        i = insert(table1, values=dict(name='foo'), postgres_returning=[func.length(table1.c.name)])
+        self.assert_compile(i, "INSERT INTO mytable (name) VALUES (%(name)s) RETURNING length(mytable.name)", dialect=dialect)
+    
     def testdelete(self):
         self.assert_compile(delete(table1, table1.c.myid == 7), "DELETE FROM mytable WHERE mytable.myid = :mytable_myid")