]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
oracle+zxjdbc returning support
authorPhilip Jenvey <pjenvey@underboss.org>
Tue, 18 Aug 2009 05:28:05 +0000 (05:28 +0000)
committerPhilip Jenvey <pjenvey@underboss.org>
Tue, 18 Aug 2009 05:28:05 +0000 (05:28 +0000)
lib/sqlalchemy/dialects/oracle/zxjdbc.py
lib/sqlalchemy/test/assertsql.py
lib/sqlalchemy/test/requires.py
test/engine/test_execute.py
test/sql/test_query.py
test/sql/test_returning.py

index c2143138a08d2fea0beb31c4d7e78cc56d53dcbb..8969ebdcf1f2d7e5095ae5bdc96eb59ee83beb30 100644 (file)
@@ -1,16 +1,22 @@
-"""Support for the Oracle database via the zxjdbc JDBC connector."""
+"""Support for the Oracle database via the zxjdbc JDBC connector.
+
+JDBC Driver
+-----------
+
+The official Oracle JDBC driver is at
+http://www.oracle.com/technology/software/tech/java/sqlj_jdbc/index.html.
+
+"""
 import decimal
 import re
 
-try:
-    from com.ziclix.python.sql.handler import OracleDataHandler
-except ImportError:
-    OracleDataHandler = None
-
-from sqlalchemy import types as sqltypes, util
+from sqlalchemy import sql, types as sqltypes, util
 from sqlalchemy.connectors.zxJDBC import ZxJDBCConnector
-from sqlalchemy.dialects.oracle.base import OracleDialect
-from sqlalchemy.engine.default import DefaultExecutionContext
+from sqlalchemy.dialects.oracle.base import OracleCompiler, OracleDialect
+from sqlalchemy.engine import base, default
+from sqlalchemy.sql import expression
+
+SQLException = zxJDBC = None
 
 class _JDBCDate(sqltypes.Date):
 
@@ -37,21 +43,120 @@ class _JDBCNumeric(sqltypes.Numeric):
             return process
 
 
-class Oracle_jdbcExecutionContext(DefaultExecutionContext):
+class Oracle_jdbcCompiler(OracleCompiler):
+
+    def returning_clause(self, stmt, returning_cols):
+        columnlist = list(expression._select_iterables(returning_cols))
+
+        # within_columns_clause=False so that labels (foo AS bar) don't render
+        columns = [self.process(c, within_columns_clause=False, result_map=self.result_map)
+                   for c in columnlist]
+
+        if not hasattr(self, 'returning_parameters'):
+            self.returning_parameters = []
+
+        binds = []
+        for i, col in enumerate(columnlist):
+            dbtype = col.type.dialect_impl(self.dialect).get_dbapi_type(self.dialect.dbapi)
+            self.returning_parameters.append((i + 1, dbtype))
+
+            bindparam = sql.bindparam("ret_%d" % i, value=ReturningParam(dbtype))
+            self.binds[bindparam.key] = bindparam
+            binds.append(self.bindparam_string(self._truncate_bindparam(bindparam)))
+
+        return 'RETURNING ' + ', '.join(columns) +  " INTO " + ", ".join(binds)
+
+
+class Oracle_jdbcExecutionContext(default.DefaultExecutionContext):
+
+    def pre_exec(self):
+        if hasattr(self.compiled, 'returning_parameters'):
+            # prepare a zxJDBC statement so we can grab its underlying
+            # OraclePreparedStatement's getReturnResultSet later
+            self.statement = self.cursor.prepare(self.statement)
+
+    def get_result_proxy(self):
+        if hasattr(self.compiled, 'returning_parameters'):
+            rrs = None
+            try:
+                try:
+                    rrs = self.statement.__statement__.getReturnResultSet()
+                    rrs.next()
+                except SQLException, sqle:
+                    msg = '%s [SQLCode: %d]' % (sqle.getMessage(), sqle.getErrorCode())
+                    if sqle.getSQLState() is not None:
+                        msg += ' [SQLState: %s]' % sqle.getSQLState()
+                    raise zxJDBC.Error(msg)
+                else:
+                    row = tuple(self.cursor.datahandler.getPyObject(rrs, index, dbtype)
+                                for index, dbtype in self.compiled.returning_parameters)
+                    return ReturningResultProxy(self, row)
+            finally:
+                if rrs is not None:
+                    try:
+                        rrs.close()
+                    except SQLException:
+                        pass
+                self.statement.close()
+
+        return base.ResultProxy(self)
 
     def create_cursor(self):
         cursor = self._connection.connection.cursor()
-        cursor.cursor.datahandler = OracleDataHandler(cursor.cursor.datahandler)
+        cursor.cursor.datahandler = self.dialect.DataHandler(cursor.cursor.datahandler)
         return cursor
 
 
+class ReturningResultProxy(base.FullyBufferedResultProxy):
+
+    """ResultProxy backed by the RETURNING ResultSet results."""
+
+    def __init__(self, context, returning_row):
+        self._returning_row = returning_row
+        super(ReturningResultProxy, self).__init__(context)
+
+    def _cursor_description(self):
+        returning = self.context.compiled.returning
+
+        ret = []
+        for c in returning:
+            if hasattr(c, 'name'):
+                ret.append((c.name, c.type))
+            else:
+                ret.append((c.anon_label, c.type))
+        return ret
+
+    def _buffer_rows(self):
+        return [self._returning_row]
+
+
+class ReturningParam(object):
+
+    """A bindparam value representing a RETURNING parameter.
+
+    Specially handled by OracleReturningDataHandler.
+    """
+
+    def __init__(self, type):
+        self.type = type
+
+    def __eq__(self, other):
+        if isinstance(other, ReturningParam):
+            return self.type == other.type
+        return NotImplemented
+
+    def __repr__(self):
+        kls = self.__class__
+        return '<%s.%s object at 0x%x type=%s>' % (kls.__module__, kls.__name__, id(self),
+                                                   self.type)
+
+
 class Oracle_jdbc(ZxJDBCConnector, OracleDialect):
+    statement_compiler = Oracle_jdbcCompiler
     execution_ctx_cls = Oracle_jdbcExecutionContext
     jdbc_db_name = 'oracle'
     jdbc_driver_name = 'oracle.jdbc.OracleDriver'
 
-    implicit_returning = False
-
     colspecs = util.update_copy(
         OracleDialect.colspecs,
         {
@@ -60,9 +165,28 @@ class Oracle_jdbc(ZxJDBCConnector, OracleDialect):
         }
     )
 
+    def __init__(self, *args, **kwargs):
+        super(Oracle_jdbc, self).__init__(*args, **kwargs)
+        global SQLException, zxJDBC
+        from java.sql import SQLException
+        from com.ziclix.python.sql import zxJDBC
+        from com.ziclix.python.sql.handler import OracleDataHandler
+        class OracleReturningDataHandler(OracleDataHandler):
+
+            """zxJDBC DataHandler that specially handles ReturningParam."""
+
+            def setJDBCObject(self, statement, index, object, dbtype=None):
+                if type(object) is ReturningParam:
+                    statement.registerReturnParameter(index, object.type)
+                elif dbtype is None:
+                    OracleDataHandler.setJDBCObject(self, statement, index, object)
+                else:
+                    OracleDataHandler.setJDBCObject(self, statement, index, object, dbtype)
+        self.DataHandler = OracleReturningDataHandler
+
     def initialize(self, connection):
         super(Oracle_jdbc, self).initialize(connection)
-        self.implicit_returning = False
+        self.implicit_returning = connection.connection.driverversion >= '10.2'
 
     def _create_jdbc_url(self, url):
         return 'jdbc:oracle:thin:@%s:%s:%s' % (url.host, url.port or 1521, url.database)
index 1af28794eda0cbc21145aae77ac820191ed77443..6dbc95b784fafeb8f428e76a65cf06d23b38f8f5 100644 (file)
@@ -216,6 +216,9 @@ class AllOf(AssertRule):
         return len(self.rules) == 0
         
 def _process_engine_statement(query, context):
+    if util.jython:
+        # oracle+zxjdbc passes a PyStatement when returning into
+        query = unicode(query)
     if context.engine.name == 'mssql' and query.endswith('; select scope_identity()'):
         query = query[:-25]
     
index c1f8d31689dc90e68ad4a9d45ede7be85629065d..f3f4ec1911c9f0a7243408d282429dee322ba1cd 100644 (file)
@@ -140,7 +140,6 @@ def returning(fn):
         no_support('maxdb', 'not supported by database'),
         no_support('sybase', 'not supported by database'),
         no_support('informix', 'not supported by database'),
-        no_support('oracle+zxjdbc', 'FIXME: tricky; currently broken'),
     )
     
 def two_phase_transactions(fn):
index c47f038c40dd14174f1104d449baa3cc85956d30..7ec4124a9899997da13ada4b7485a5668d290b38 100644 (file)
@@ -108,7 +108,7 @@ class ProxyConnectionTest(TestBase):
             def cursor_execute(self, execute, cursor, statement, parameters, context, executemany):
                 print "CE", statement, parameters
                 cursor_stmts.append(
-                    (statement, parameters, None)
+                    (str(statement), parameters, None)
                 )
                 return execute(cursor, statement, parameters, context)
         
@@ -148,7 +148,7 @@ class ProxyConnectionTest(TestBase):
                 ("DROP TABLE t1", {}, None)
             ]
 
-            if True: # or engine.dialect.preexecute_pk_sequences:
+            if not testing.against('oracle+zxjdbc'): # or engine.dialect.preexecute_pk_sequences:
                 cursor = [
                     ("CREATE TABLE t1", {}, ()),
                     ("INSERT INTO t1 (c1, c2)", {'c2': 'some data', 'c1': 5}, [5, 'some data']),
@@ -158,10 +158,14 @@ class ProxyConnectionTest(TestBase):
                     ("DROP TABLE t1", {}, ())
                 ]
             else:
+                insert2_params = [6, 'Foo']
+                if testing.against('oracle+zxjdbc'):
+                    from sqlalchemy.dialects.oracle.zxjdbc import ReturningParam
+                    insert2_params.append(ReturningParam(12))
                 cursor = [
                     ("CREATE TABLE t1", {}, ()),
                     ("INSERT INTO t1 (c1, c2)", {'c2': 'some data', 'c1': 5}, [5, 'some data']),
-                    ("INSERT INTO t1 (c1, c2)", {'c1': 6, "lower_2":"Foo"}, [6, "Foo"]),  # bind param name 'lower_2' might be incorrect
+                    ("INSERT INTO t1 (c1, c2)", {'c1': 6, "lower_2":"Foo"}, insert2_params),  # bind param name 'lower_2' might be incorrect
                     ("select * from t1", {}, ()),
                     ("DROP TABLE t1", {}, ())
                 ]
index 934bdadbee6fef79dc65f4b8bee1984209767374..3222ff6ef42b2823a0cb707cc330561af941bfc4 100644 (file)
@@ -80,8 +80,7 @@ class QueryTest(TestBase):
                     ret[c.key] = row[c]
             return ret
 
-        if (testing.against('firebird', 'postgresql', 'oracle', 'mssql') and
-            not testing.against('oracle+zxjdbc')):
+        if testing.against('firebird', 'postgresql', 'oracle', 'mssql'):
             test_engines = [
                 engines.testing_engine(options={'implicit_returning':False}),
                 engines.testing_engine(options={'implicit_returning':True}),
@@ -168,8 +167,7 @@ class QueryTest(TestBase):
         eq_(r.inserted_primary_key, [12, 1])
 
     def test_autoclose_on_insert(self):
-        if (testing.against('firebird', 'postgresql', 'oracle', 'mssql') and
-            not testing.against('oracle+zxjdbc')):
+        if testing.against('firebird', 'postgresql', 'oracle', 'mssql'):
             test_engines = [
                 engines.testing_engine(options={'implicit_returning':False}),
                 engines.testing_engine(options={'implicit_returning':True}),
index 474e0b369215ad9b81c5ae582b32f492ff5e23ba..02d906dd846983c0b3eed39135a75ffc7780f317 100644 (file)
@@ -5,7 +5,7 @@ from sqlalchemy.test.schema import Table, Column
 from sqlalchemy.types import TypeDecorator
 
 class ReturningTest(TestBase, AssertsExecutionResults):
-    __unsupported_on__ = ('sqlite', 'mysql', 'maxdb', 'sybase', 'access', 'oracle+zxjdbc')
+    __unsupported_on__ = ('sqlite', 'mysql', 'maxdb', 'sybase', 'access')
 
     def setup(self):
         meta = MetaData(testing.db)
@@ -61,6 +61,7 @@ class ReturningTest(TestBase, AssertsExecutionResults):
         assert row['lala'] == 6
 
     @testing.fails_on('firebird', "fb/kintersbasdb can't handle the bind params")
+    @testing.fails_on('oracle+zxjdbc', "JDBC driver bug")
     @testing.exclude('firebird', '<', (2, 0), '2.0+ feature')
     @testing.exclude('postgresql', '<', (8, 2), '8.3+ feature')
     def test_anon_expressions(self):
@@ -92,6 +93,8 @@ class ReturningTest(TestBase, AssertsExecutionResults):
 
         eq_(result.fetchall(), [(1,)])
 
+        @testing.crashes('oracle+zxjdbc', 'Triggers a "No more data to read from socket" and '
+                         'prevents table from being dropped')
         @testing.fails_on('postgresql', '')
         @testing.fails_on('oracle', '')
         def test_executemany():
@@ -109,7 +112,8 @@ class ReturningTest(TestBase, AssertsExecutionResults):
         test_executemany()
 
         result3 = table.insert().returning(table.c.id).execute({'persons': 4, 'full': False})
-        eq_([dict(row) for row in result3], [{'id': 4}])
+        next = testing.against('oracle+zxjdbc') and 2 or 4
+        eq_([dict(row) for row in result3], [{'id': next}])
     
         
     @testing.exclude('firebird', '<', (2, 1), '2.1+ feature')
@@ -137,7 +141,7 @@ class ReturningTest(TestBase, AssertsExecutionResults):
         eq_(result2.fetchall(), [(2,False),])
 
 class SequenceReturningTest(TestBase):
-    __unsupported_on__ = ('sqlite', 'mysql', 'maxdb', 'sybase', 'access', 'mssql', 'oracle+zxjdbc')
+    __unsupported_on__ = ('sqlite', 'mysql', 'maxdb', 'sybase', 'access', 'mssql')
 
     def setup(self):
         meta = MetaData(testing.db)
@@ -160,7 +164,7 @@ class SequenceReturningTest(TestBase):
 class KeyReturningTest(TestBase, AssertsExecutionResults):
     """test returning() works with columns that define 'key'."""
     
-    __unsupported_on__ = ('sqlite', 'mysql', 'maxdb', 'sybase', 'access', 'oracle+zxjdbc')
+    __unsupported_on__ = ('sqlite', 'mysql', 'maxdb', 'sybase', 'access')
 
     def setup(self):
         meta = MetaData(testing.db)