From e82ca71cc5c4175f071cdd72207ec04e58a6498c Mon Sep 17 00:00:00 2001 From: Ants Aasma Date: Tue, 2 Oct 2007 23:57:54 +0000 Subject: [PATCH] add support for returning results from inserts and updates for postgresql 8.2+. [ticket:797] --- CHANGES | 3 ++ lib/sqlalchemy/databases/postgres.py | 50 ++++++++++++++++++++++++++-- lib/sqlalchemy/sql/expression.py | 24 +++++++------ test/sql/query.py | 46 ++++++++++++++++++++++++- test/sql/select.py | 28 ++++++++++++++++ 5 files changed, 137 insertions(+), 14 deletions(-) diff --git a/CHANGES b/CHANGES index 78f9435939..0bab8b9fc3 100644 --- a/CHANGES +++ b/CHANGES @@ -24,6 +24,9 @@ CHANGES - The no-arg ResultProxy._row_processor() is now the class attribute `_process_row`. +- Added support for returning values from inserts and udpates for + PostgreSQL 8.2+. [ticket:797] + 0.4.0beta6 ---------- diff --git a/lib/sqlalchemy/databases/postgres.py b/lib/sqlalchemy/databases/postgres.py index 305e8e8317..03b3fd042f 100644 --- a/lib/sqlalchemy/databases/postgres.py +++ b/lib/sqlalchemy/databases/postgres.py @@ -10,13 +10,20 @@ PostgreSQL supports partial indexes. To create them pass a posgres_where option to the Index constructor:: Index('my_index', my_table.c.id, postgres_where=tbl.c.value > 10) + +PostgreSQL 8.2+ supports returning a result set from inserts and updates. +To use this pass the column/expression list to the postgres_returning +parameter when creating the queries:: + + raises = tbl.update(empl.c.sales > 100, values=dict(salary=empl.c.salary * 1.1), + postgres_returning=[empl.c.id, empl.c.salary]).execute().fetchall() """ import re, random, warnings, string from sqlalchemy import sql, schema, exceptions, util from sqlalchemy.engine import base, default -from sqlalchemy.sql import compiler +from sqlalchemy.sql import compiler, expression from sqlalchemy.sql import operators as sql_operators from sqlalchemy import types as sqltypes @@ -198,13 +205,27 @@ def descriptor(): ]} SELECT_RE = re.compile( - r'\s*(?:SELECT|FETCH)', + r'\s*(?:SELECT|FETCH|(UPDATE|INSERT))', + re.I | re.UNICODE) + +RETURNING_RE = re.compile( + 'RETURNING', + re.I | re.UNICODE) + +# This finds if the RETURNING is not inside a quoted/commented values. Handles string literals, +# quoted identifiers, dollar quotes, SQL comments and C style multiline comments. This does not +# handle correctly nested C style quotes, lets hope no one does the following: +# UPDATE tbl SET x=y /* foo /* bar */ RETURNING */ +RETURNING_QUOTED_RE = re.compile( + '\\s*(?:UPDATE|INSERT)\\s(?:[^\'"$/-]|-(?!-)|/(?!\\*)|"(?:[^"]|"")*"|\'(?:[^\']|\'\')*\'|\\$(?P[^$]*)\\$.*?\\$(?P=dquote)\\$|--[^\n]*\n|/\\*([^*]|\\*(?!/))*\\*/)*\\sRETURNING', re.I | re.UNICODE) class PGExecutionContext(default.DefaultExecutionContext): def is_select(self): - return SELECT_RE.match(self.statement) + m = SELECT_RE.match(self.statement) + return m and (not m.group(1) or (RETURNING_RE.search(self.statement) + and RETURNING_QUOTED_RE.match(self.statement))) def create_cursor(self): # executing a default or Sequence standalone creates an execution context without a statement. @@ -598,6 +619,29 @@ class PGCompiler(compiler.DefaultCompiler): else: return super(PGCompiler, self).for_update_clause(select) + def _append_returning(self, text, stmt): + returning_cols = stmt.kwargs.get('postgres_returning', None) + if returning_cols: + def flatten_columnlist(collist): + for c in collist: + if isinstance(c, expression.Selectable): + for co in c.columns: + yield co + else: + yield c + columns = [self.process(c) for c in flatten_columnlist(returning_cols)] + text += ' RETURNING ' + string.join(columns, ', ') + + return text + + def visit_update(self, update_stmt): + text = super(PGCompiler, self).visit_update(update_stmt) + return self._append_returning(text, update_stmt) + + def visit_insert(self, insert_stmt): + text = super(PGCompiler, self).visit_insert(insert_stmt) + return self._append_returning(text, insert_stmt) + class PGSchemaGenerator(compiler.SchemaGenerator): def get_column_specification(self, column, **kwargs): colspec = self.preparer.format_column(column) diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index 9a9cf65d22..6f3ee94abd 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -249,7 +249,7 @@ def subquery(alias, *args, **kwargs): return Select(*args, **kwargs).alias(alias) -def insert(table, values=None, inline=False): +def insert(table, values=None, inline=False, **kwargs): """Return an [sqlalchemy.sql.expression#Insert] clause element. Similar functionality is available via the ``insert()`` method on @@ -287,9 +287,9 @@ def insert(table, values=None, inline=False): against the ``INSERT`` statement. """ - return Insert(table, values, inline=inline) + return Insert(table, values, inline=inline, **kwargs) -def update(table, whereclause=None, values=None, inline=False): +def update(table, whereclause=None, values=None, inline=False, **kwargs): """Return an [sqlalchemy.sql.expression#Update] clause element. Similar functionality is available via the ``update()`` method on @@ -332,7 +332,7 @@ def update(table, whereclause=None, values=None, inline=False): against the ``UPDATE`` statement. """ - return Update(table, whereclause=whereclause, values=values, inline=inline) + return Update(table, whereclause=whereclause, values=values, inline=inline, **kwargs) def delete(table, whereclause = None, **kwargs): """Return a [sqlalchemy.sql.expression#Delete] clause element. @@ -2699,11 +2699,11 @@ class TableClause(FromClause): def select(self, whereclause = None, **params): return select([self], whereclause, **params) - def insert(self, values=None, inline=False): - return insert(self, values=values, inline=inline) + def insert(self, values=None, inline=False, **kwargs): + return insert(self, values=values, inline=inline, **kwargs) - def update(self, whereclause=None, values=None, inline=False): - return update(self, whereclause=whereclause, values=values, inline=inline) + def update(self, whereclause=None, values=None, inline=False, **kwargs): + return update(self, whereclause=whereclause, values=values, inline=inline, **kwargs) def delete(self, whereclause=None): return delete(self, whereclause) @@ -3356,12 +3356,14 @@ class _UpdateBase(ClauseElement): return self.table.bind class Insert(_UpdateBase): - def __init__(self, table, values=None, inline=False): + def __init__(self, table, values=None, inline=False, **kwargs): self.table = table self.select = None self.inline=inline self.parameters = self._process_colparams(values) + self.kwargs = kwargs + def get_children(self, **kwargs): if self.select is not None: return self.select, @@ -3383,12 +3385,14 @@ class Insert(_UpdateBase): return u class Update(_UpdateBase): - def __init__(self, table, whereclause, values=None, inline=False): + def __init__(self, table, whereclause, values=None, inline=False, **kwargs): self.table = table self._whereclause = whereclause self.inline = inline self.parameters = self._process_colparams(values) + self.kwargs = kwargs + def get_children(self, **kwargs): if self._whereclause is not None: return self._whereclause, diff --git a/test/sql/query.py b/test/sql/query.py index a519dd974b..b4afbbade0 100644 --- a/test/sql/query.py +++ b/test/sql/query.py @@ -563,7 +563,51 @@ class QueryTest(PersistTest): s = users.select(users.c.user_name.in_() == None) r = s.execute().fetchall() assert len(r) == 1 - + + @testing.supported('postgres') + def test_update_returning(self): + meta = MetaData(testbase.db) + table = Table('tables', meta, + Column('id', Integer, primary_key=True), + Column('persons', Integer), + Column('full', Boolean) + ) + table.create() + try: + table.insert().execute([{'persons': 5, 'full': False}, {'persons': 3, 'full': False}]) + + result = table.update(table.c.persons > 4, dict(full=True), postgres_returning=[table.c.id]).execute() + self.assertEqual(result.fetchall(), [(1,)]) + + result2 = select([table.c.id, table.c.full]).order_by(table.c.id).execute() + self.assertEqual(result2.fetchall(), [(1,True),(2,False)]) + finally: + table.drop() + + @testing.supported('postgres') + def test_insert_returning(self): + meta = MetaData(testbase.db) + table = Table('tables', meta, + Column('id', Integer, primary_key=True), + Column('persons', Integer), + Column('full', Boolean) + ) + table.create() + try: + result = table.insert(postgres_returning=[table.c.id]).execute({'persons': 1, 'full': False}) + + self.assertEqual(result.fetchall(), [(1,)]) + + # Multiple inserts only return the last row + result2 = table.insert(postgres_returning=[table]).execute( + [{'persons': 2, 'full': False}, {'persons': 3, 'full': True}]) + + self.assertEqual(result2.fetchall(), [(3,3,True)]) + + result3 = table.insert(postgres_returning=[(table.c.id*2).label('double_id')]).execute({'persons': 4, 'full': False}) + self.assertEqual([dict(row) for row in result3], [{'double_id':8}]) + finally: + table.drop() class CompoundTest(PersistTest): """test compound statements like UNION, INTERSECT, particularly their ability to nest on diff --git a/test/sql/select.py b/test/sql/select.py index 1114f17358..4cdac97d87 100644 --- a/test/sql/select.py +++ b/test/sql/select.py @@ -1184,7 +1184,35 @@ class CRUDTest(SQLCompileTest): s = select([table2.c.othername], table2.c.otherid == table1.c.myid) u = table1.update(table1.c.name==s) self.assert_compile(u, "UPDATE mytable SET myid=:myid, name=:name, description=:description WHERE mytable.name = (SELECT myothertable.othername FROM myothertable WHERE myothertable.otherid = mytable.myid)") + + @testing.supported('postgres') + def testupdatereturning(self): + dialect = postgres.dialect() + + u = update(table1, values=dict(name='foo'), postgres_returning=[table1.c.myid, table1.c.name]) + self.assert_compile(u, "UPDATE mytable SET name=%(name)s RETURNING mytable.myid, mytable.name", dialect=dialect) + + u = update(table1, values=dict(name='foo'), postgres_returning=[table1]) + self.assert_compile(u, "UPDATE mytable SET name=%(name)s "\ + "RETURNING mytable.myid, mytable.name, mytable.description", dialect=dialect) + + u = update(table1, values=dict(name='foo'), postgres_returning=[func.length(table1.c.name)]) + self.assert_compile(u, "UPDATE mytable SET name=%(name)s RETURNING length(mytable.name)", dialect=dialect) + @testing.supported('postgres') + def testinsertreturning(self): + dialect = postgres.dialect() + + i = insert(table1, values=dict(name='foo'), postgres_returning=[table1.c.myid, table1.c.name]) + self.assert_compile(i, "INSERT INTO mytable (name) VALUES (%(name)s) RETURNING mytable.myid, mytable.name", dialect=dialect) + + i = insert(table1, values=dict(name='foo'), postgres_returning=[table1]) + self.assert_compile(i, "INSERT INTO mytable (name) VALUES (%(name)s) "\ + "RETURNING mytable.myid, mytable.name, mytable.description", dialect=dialect) + + i = insert(table1, values=dict(name='foo'), postgres_returning=[func.length(table1.c.name)]) + self.assert_compile(i, "INSERT INTO mytable (name) VALUES (%(name)s) RETURNING length(mytable.name)", dialect=dialect) + def testdelete(self): self.assert_compile(delete(table1, table1.c.myid == 7), "DELETE FROM mytable WHERE mytable.myid = :mytable_myid") -- 2.47.3