]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Support Firebird 2.0+ RETURNING
authorLele Gaifax <lele@metapensiero.it>
Wed, 14 May 2008 15:31:29 +0000 (15:31 +0000)
committerLele Gaifax <lele@metapensiero.it>
Wed, 14 May 2008 15:31:29 +0000 (15:31 +0000)
CHANGES
lib/sqlalchemy/databases/firebird.py
test/dialect/firebird.py

diff --git a/CHANGES b/CHANGES
index 3f6ddbd03c4cf46835fdc96f4ec21c5ea6ce112d..3eacec328843a80943f7fc3cd9f84cf8185c6eaa 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -66,6 +66,9 @@ user_defined_state
       incompatible; previously the extensions of last mapper defined
       would receive these events.
 
+- firebird
+    - Added support for returning values from inserts (2.0+ only),
+      updates and deletes (2.1+ only).
 
 0.4.6
 =====
@@ -1573,7 +1576,7 @@ user_defined_state
 - The no-arg ResultProxy._row_processor() is now the class attribute
   `_process_row`.
 
-- Added support for returning values from inserts and udpates for
+- Added support for returning values from inserts and updates for
   PostgreSQL 8.2+. [ticket:797]
 
 - PG reflection, upon seeing the default schema name being used explicitly
index 948e001d5358e891ff11b58debff23cfbf255e3c..d3662ccbfd90b6184c2156be955796a6512e5d88 100644 (file)
@@ -78,6 +78,18 @@ connections are active, the following setting may alleviate the problem::
   # Force SA to use a single connection per thread
   dialect.poolclass = pool.SingletonThreadPool
 
+RETURNING support
+-----------------
+
+Firebird 2.0 supports returning a result set from inserts, and 2.1 extends
+that to deletes and updates.
+
+To use this pass the column/expression list to the ``firebird_returning``
+parameter when creating the queries::
+
+  raises = tbl.update(empl.c.sales > 100, values=dict(salary=empl.c.salary * 1.1),
+                      firebird_returning=[empl.c.id, empl.c.salary]).execute().fetchall()
+
 
 .. [#] Well, that is not the whole story, as the client may still ask
        a different (lower) dialect...
@@ -87,7 +99,7 @@ connections are active, the following setting may alleviate the problem::
 """
 
 
-import datetime
+import datetime, re
 
 from sqlalchemy import exc, schema, types as sqltypes, sql, util
 from sqlalchemy.engine import base, default
@@ -261,8 +273,44 @@ def descriptor():
     ]}
 
 
+SELECT_RE = re.compile(
+    r'\s*(?:SELECT|(UPDATE|INSERT|DELETE))',
+    re.I | re.UNICODE)
+
+RETURNING_RE = re.compile(
+    'RETURNING',
+    re.I | re.UNICODE)
+
+# This finds if the RETURNING is not inside a quoted/commented values. Handles string literals,
+# quoted identifiers, dollar quotes, SQL comments and C style multiline comments. This does not
+# handle correctly nested C style quotes, lets hope no one does the following:
+# UPDATE tbl SET x=y /* foo /* bar */ RETURNING */
+RETURNING_QUOTED_RE = re.compile(
+    """\s*(?:UPDATE|INSERT|DELETE)\s
+        (?: # handle quoted and commented tokens separately
+            [^'"$/-] # non quote/comment character
+            | -(?!-) # a dash that does not begin a comment
+            | /(?!\*) # a slash that does not begin a comment
+            | "(?:[^"]|"")*" # quoted literal
+            | '(?:[^']|'')*' # quoted string
+            | --[^\\n]*(?=\\n) # SQL comment, leave out line ending as that counts as whitespace
+                               # for the returning token
+            | /\*([^*]|\*(?!/))*\*/ # C style comment, doesn't handle nesting
+        )*
+        \sRETURNING\s""", re.I | re.UNICODE | re.VERBOSE)
+
+RETURNING_KW_NAME = 'firebird_returning'
+
 class FBExecutionContext(default.DefaultExecutionContext):
-    pass
+    def returns_rows_text(self, statement):
+        m = SELECT_RE.match(statement)
+        return m and (not m.group(1) or (RETURNING_RE.search(statement)
+                                         and RETURNING_QUOTED_RE.match(statement)))
+
+    def returns_rows_compiled(self, compiled):
+        return (isinstance(compiled.statement, sql.expression.Selectable) or
+                ((compiled.isupdate or compiled.isinsert or compiler.isdelete) and
+                 RETURNING_KW_NAME in compiled.statement.kwargs))
 
 
 class FBDialect(default.DefaultDialect):
@@ -629,6 +677,41 @@ class FBCompiler(sql.compiler.DefaultCompiler):
             return self.LENGTH_FUNCTION_NAME + '%(expr)s'
         return super(FBCompiler, self).function_string(func)
 
+    def _append_returning(self, text, stmt):
+        returning_cols = stmt.kwargs[RETURNING_KW_NAME]
+        def flatten_columnlist(collist):
+            for c in collist:
+                if isinstance(c, sql.expression.Selectable):
+                    for co in c.columns:
+                        yield co
+                else:
+                    yield c
+        columns = [self.process(c, render_labels=True)
+                   for c in flatten_columnlist(returning_cols)]
+        text += ' RETURNING ' + ', '.join(columns)
+        return text
+
+    def visit_update(self, update_stmt):
+        text = super(FBCompiler, self).visit_update(update_stmt)
+        if RETURNING_KW_NAME in update_stmt.kwargs:
+            return self._append_returning(text, update_stmt)
+        else:
+            return text
+
+    def visit_insert(self, insert_stmt):
+        text = super(FBCompiler, self).visit_insert(insert_stmt)
+        if RETURNING_KW_NAME in insert_stmt.kwargs:
+            return self._append_returning(text, insert_stmt)
+        else:
+            return text
+
+    def visit_delete(self, delete_stmt):
+        text = super(FBCompiler, self).visit_delete(delete_stmt)
+        if RETURNING_KW_NAME in delete_stmt.kwargs:
+            return self._append_returning(text, delete_stmt)
+        else:
+            return text
+
 
 class FBSchemaGenerator(sql.compiler.SchemaGenerator):
     """Firebird syntactic idiosincrasies"""
index da6cc697094884679234be5042b4a5e02026fe7e..4b4b9fd7f61c21a051edf90e31ae82f85c0fcf4c 100644 (file)
@@ -89,8 +89,114 @@ class CompileTest(TestBase, AssertsCompiledSQL):
         self.assert_compile(func.substring('abc', 1, 2), "SUBSTRING(:substring_1 FROM :substring_2 FOR :substring_3)")
         self.assert_compile(func.substring('abc', 1), "SUBSTRING(:substring_1 FROM :substring_2)")
 
-class MiscFBTests(TestBase):
+    def test_update_returning(self):
+        table1 = table('mytable',
+            column('myid', Integer),
+            column('name', String(128)),
+            column('description', String(128)),
+        )
+
+        u = update(table1, values=dict(name='foo'), firebird_returning=[table1.c.myid, table1.c.name])
+        self.assert_compile(u, "UPDATE mytable SET name=:name RETURNING mytable.myid, mytable.name")
+
+        u = update(table1, values=dict(name='foo'), firebird_returning=[table1])
+        self.assert_compile(u, "UPDATE mytable SET name=:name "\
+            "RETURNING mytable.myid, mytable.name, mytable.description")
+
+        u = update(table1, values=dict(name='foo'), firebird_returning=[func.length(table1.c.name)])
+        self.assert_compile(u, "UPDATE mytable SET name=:name RETURNING char_length(mytable.name)")
+
+    def test_insert_returning(self):
+        table1 = table('mytable',
+            column('myid', Integer),
+            column('name', String(128)),
+            column('description', String(128)),
+        )
+
+        i = insert(table1, values=dict(name='foo'), firebird_returning=[table1.c.myid, table1.c.name])
+        self.assert_compile(i, "INSERT INTO mytable (name) VALUES (:name) RETURNING mytable.myid, mytable.name")
+
+        i = insert(table1, values=dict(name='foo'), firebird_returning=[table1])
+        self.assert_compile(i, "INSERT INTO mytable (name) VALUES (:name) "\
+            "RETURNING mytable.myid, mytable.name, mytable.description")
+
+        i = insert(table1, values=dict(name='foo'), firebird_returning=[func.length(table1.c.name)])
+        self.assert_compile(i, "INSERT INTO mytable (name) VALUES (:name) RETURNING char_length(mytable.name)")
+
+
+class ReturningTest(TestBase, AssertsExecutionResults):
+    __only_on__ = 'firebird'
+
+    @testing.exclude('firebird', '<', (2, 1), '2.1+ feature')
+    def test_update_returning(self):
+        meta = MetaData(testing.db)
+        table = Table('tables', meta,
+            Column('id', Integer, Sequence('gen_tables_id'), primary_key=True),
+            Column('persons', Integer),
+            Column('full', Boolean)
+        )
+        table.create()
+        try:
+            table.insert().execute([{'persons': 5, 'full': False}, {'persons': 3, 'full': False}])
+
+            result = table.update(table.c.persons > 4, dict(full=True), firebird_returning=[table.c.id]).execute()
+            self.assertEqual(result.fetchall(), [(1,)])
+
+            result2 = select([table.c.id, table.c.full]).order_by(table.c.id).execute()
+            self.assertEqual(result2.fetchall(), [(1,True),(2,False)])
+        finally:
+            table.drop()
 
+    @testing.exclude('firebird', '<', (2, 0), '2.0+ feature')
+    def test_insert_returning(self):
+        meta = MetaData(testing.db)
+        table = Table('tables', meta,
+            Column('id', Integer, Sequence('gen_tables_id'), primary_key=True),
+            Column('persons', Integer),
+            Column('full', Boolean)
+        )
+        table.create()
+        try:
+            result = table.insert(firebird_returning=[table.c.id]).execute({'persons': 1, 'full': False})
+
+            self.assertEqual(result.fetchall(), [(1,)])
+
+            # Multiple inserts only return the last row
+            result2 = table.insert(firebird_returning=[table]).execute(
+                 [{'persons': 2, 'full': False}, {'persons': 3, 'full': True}])
+
+            self.assertEqual(result2.fetchall(), [(3,3,True)])
+
+            result3 = table.insert(firebird_returning=[table.c.id]).execute({'persons': 4, 'full': False})
+            self.assertEqual([dict(row) for row in result3], [{'ID':4}])
+
+            result4 = testing.db.execute('insert into tables (id, persons, "full") values (5, 10, 1) returning persons')
+            self.assertEqual([dict(row) for row in result4], [{'PERSONS': 10}])
+        finally:
+            table.drop()
+
+    @testing.exclude('firebird', '<', (2, 1), '2.1+ feature')
+    def test_delete_returning(self):
+        meta = MetaData(testing.db)
+        table = Table('tables', meta,
+            Column('id', Integer, Sequence('gen_tables_id'), primary_key=True),
+            Column('persons', Integer),
+            Column('full', Boolean)
+        )
+        table.create()
+        try:
+            table.insert().execute([{'persons': 5, 'full': False}, {'persons': 3, 'full': False}])
+
+            result = table.delete(table.c.persons > 4, dict(full=True), firebird_returning=[table.c.id]).execute()
+            self.assertEqual(result.fetchall(), [(1,)])
+
+            result2 = select([table.c.id, table.c.full]).order_by(table.c.id).execute()
+            self.assertEqual(result2.fetchall(), [(2,False),])
+        finally:
+            table.drop()
+
+
+class MiscFBTests(TestBase):
     __only_on__ = 'firebird'
 
     def test_strlen(self):
@@ -117,5 +223,6 @@ class MiscFBTests(TestBase):
         version = testing.db.dialect.server_version_info(testing.db.connect())
         assert len(version) == 3, "Got strange version info: %s" % repr(version)
 
+
 if __name__ == '__main__':
     testenv.main()