From fc59a5e0c4731a29cc3e4852bc3629850e0c9a04 Mon Sep 17 00:00:00 2001 From: Philip Jenvey Date: Tue, 18 Aug 2009 05:28:05 +0000 Subject: [PATCH] oracle+zxjdbc returning support --- lib/sqlalchemy/dialects/oracle/zxjdbc.py | 152 ++++++++++++++++++++--- lib/sqlalchemy/test/assertsql.py | 3 + lib/sqlalchemy/test/requires.py | 1 - test/engine/test_execute.py | 10 +- test/sql/test_query.py | 6 +- test/sql/test_returning.py | 12 +- 6 files changed, 158 insertions(+), 26 deletions(-) diff --git a/lib/sqlalchemy/dialects/oracle/zxjdbc.py b/lib/sqlalchemy/dialects/oracle/zxjdbc.py index c2143138a0..8969ebdcf1 100644 --- a/lib/sqlalchemy/dialects/oracle/zxjdbc.py +++ b/lib/sqlalchemy/dialects/oracle/zxjdbc.py @@ -1,16 +1,22 @@ -"""Support for the Oracle database via the zxjdbc JDBC connector.""" +"""Support for the Oracle database via the zxjdbc JDBC connector. + +JDBC Driver +----------- + +The official Oracle JDBC driver is at +http://www.oracle.com/technology/software/tech/java/sqlj_jdbc/index.html. + +""" import decimal import re -try: - from com.ziclix.python.sql.handler import OracleDataHandler -except ImportError: - OracleDataHandler = None - -from sqlalchemy import types as sqltypes, util +from sqlalchemy import sql, types as sqltypes, util from sqlalchemy.connectors.zxJDBC import ZxJDBCConnector -from sqlalchemy.dialects.oracle.base import OracleDialect -from sqlalchemy.engine.default import DefaultExecutionContext +from sqlalchemy.dialects.oracle.base import OracleCompiler, OracleDialect +from sqlalchemy.engine import base, default +from sqlalchemy.sql import expression + +SQLException = zxJDBC = None class _JDBCDate(sqltypes.Date): @@ -37,21 +43,120 @@ class _JDBCNumeric(sqltypes.Numeric): return process -class Oracle_jdbcExecutionContext(DefaultExecutionContext): +class Oracle_jdbcCompiler(OracleCompiler): + + def returning_clause(self, stmt, returning_cols): + columnlist = list(expression._select_iterables(returning_cols)) + + # within_columns_clause=False so that labels (foo AS bar) don't render + columns = [self.process(c, within_columns_clause=False, result_map=self.result_map) + for c in columnlist] + + if not hasattr(self, 'returning_parameters'): + self.returning_parameters = [] + + binds = [] + for i, col in enumerate(columnlist): + dbtype = col.type.dialect_impl(self.dialect).get_dbapi_type(self.dialect.dbapi) + self.returning_parameters.append((i + 1, dbtype)) + + bindparam = sql.bindparam("ret_%d" % i, value=ReturningParam(dbtype)) + self.binds[bindparam.key] = bindparam + binds.append(self.bindparam_string(self._truncate_bindparam(bindparam))) + + return 'RETURNING ' + ', '.join(columns) + " INTO " + ", ".join(binds) + + +class Oracle_jdbcExecutionContext(default.DefaultExecutionContext): + + def pre_exec(self): + if hasattr(self.compiled, 'returning_parameters'): + # prepare a zxJDBC statement so we can grab its underlying + # OraclePreparedStatement's getReturnResultSet later + self.statement = self.cursor.prepare(self.statement) + + def get_result_proxy(self): + if hasattr(self.compiled, 'returning_parameters'): + rrs = None + try: + try: + rrs = self.statement.__statement__.getReturnResultSet() + rrs.next() + except SQLException, sqle: + msg = '%s [SQLCode: %d]' % (sqle.getMessage(), sqle.getErrorCode()) + if sqle.getSQLState() is not None: + msg += ' [SQLState: %s]' % sqle.getSQLState() + raise zxJDBC.Error(msg) + else: + row = tuple(self.cursor.datahandler.getPyObject(rrs, index, dbtype) + for index, dbtype in self.compiled.returning_parameters) + return ReturningResultProxy(self, row) + finally: + if rrs is not None: + try: + rrs.close() + except SQLException: + pass + self.statement.close() + + return base.ResultProxy(self) def create_cursor(self): cursor = self._connection.connection.cursor() - cursor.cursor.datahandler = OracleDataHandler(cursor.cursor.datahandler) + cursor.cursor.datahandler = self.dialect.DataHandler(cursor.cursor.datahandler) return cursor +class ReturningResultProxy(base.FullyBufferedResultProxy): + + """ResultProxy backed by the RETURNING ResultSet results.""" + + def __init__(self, context, returning_row): + self._returning_row = returning_row + super(ReturningResultProxy, self).__init__(context) + + def _cursor_description(self): + returning = self.context.compiled.returning + + ret = [] + for c in returning: + if hasattr(c, 'name'): + ret.append((c.name, c.type)) + else: + ret.append((c.anon_label, c.type)) + return ret + + def _buffer_rows(self): + return [self._returning_row] + + +class ReturningParam(object): + + """A bindparam value representing a RETURNING parameter. + + Specially handled by OracleReturningDataHandler. + """ + + def __init__(self, type): + self.type = type + + def __eq__(self, other): + if isinstance(other, ReturningParam): + return self.type == other.type + return NotImplemented + + def __repr__(self): + kls = self.__class__ + return '<%s.%s object at 0x%x type=%s>' % (kls.__module__, kls.__name__, id(self), + self.type) + + class Oracle_jdbc(ZxJDBCConnector, OracleDialect): + statement_compiler = Oracle_jdbcCompiler execution_ctx_cls = Oracle_jdbcExecutionContext jdbc_db_name = 'oracle' jdbc_driver_name = 'oracle.jdbc.OracleDriver' - implicit_returning = False - colspecs = util.update_copy( OracleDialect.colspecs, { @@ -60,9 +165,28 @@ class Oracle_jdbc(ZxJDBCConnector, OracleDialect): } ) + def __init__(self, *args, **kwargs): + super(Oracle_jdbc, self).__init__(*args, **kwargs) + global SQLException, zxJDBC + from java.sql import SQLException + from com.ziclix.python.sql import zxJDBC + from com.ziclix.python.sql.handler import OracleDataHandler + class OracleReturningDataHandler(OracleDataHandler): + + """zxJDBC DataHandler that specially handles ReturningParam.""" + + def setJDBCObject(self, statement, index, object, dbtype=None): + if type(object) is ReturningParam: + statement.registerReturnParameter(index, object.type) + elif dbtype is None: + OracleDataHandler.setJDBCObject(self, statement, index, object) + else: + OracleDataHandler.setJDBCObject(self, statement, index, object, dbtype) + self.DataHandler = OracleReturningDataHandler + def initialize(self, connection): super(Oracle_jdbc, self).initialize(connection) - self.implicit_returning = False + self.implicit_returning = connection.connection.driverversion >= '10.2' def _create_jdbc_url(self, url): return 'jdbc:oracle:thin:@%s:%s:%s' % (url.host, url.port or 1521, url.database) diff --git a/lib/sqlalchemy/test/assertsql.py b/lib/sqlalchemy/test/assertsql.py index 1af28794ed..6dbc95b784 100644 --- a/lib/sqlalchemy/test/assertsql.py +++ b/lib/sqlalchemy/test/assertsql.py @@ -216,6 +216,9 @@ class AllOf(AssertRule): return len(self.rules) == 0 def _process_engine_statement(query, context): + if util.jython: + # oracle+zxjdbc passes a PyStatement when returning into + query = unicode(query) if context.engine.name == 'mssql' and query.endswith('; select scope_identity()'): query = query[:-25] diff --git a/lib/sqlalchemy/test/requires.py b/lib/sqlalchemy/test/requires.py index c1f8d31689..f3f4ec1911 100644 --- a/lib/sqlalchemy/test/requires.py +++ b/lib/sqlalchemy/test/requires.py @@ -140,7 +140,6 @@ def returning(fn): no_support('maxdb', 'not supported by database'), no_support('sybase', 'not supported by database'), no_support('informix', 'not supported by database'), - no_support('oracle+zxjdbc', 'FIXME: tricky; currently broken'), ) def two_phase_transactions(fn): diff --git a/test/engine/test_execute.py b/test/engine/test_execute.py index c47f038c40..7ec4124a98 100644 --- a/test/engine/test_execute.py +++ b/test/engine/test_execute.py @@ -108,7 +108,7 @@ class ProxyConnectionTest(TestBase): def cursor_execute(self, execute, cursor, statement, parameters, context, executemany): print "CE", statement, parameters cursor_stmts.append( - (statement, parameters, None) + (str(statement), parameters, None) ) return execute(cursor, statement, parameters, context) @@ -148,7 +148,7 @@ class ProxyConnectionTest(TestBase): ("DROP TABLE t1", {}, None) ] - if True: # or engine.dialect.preexecute_pk_sequences: + if not testing.against('oracle+zxjdbc'): # or engine.dialect.preexecute_pk_sequences: cursor = [ ("CREATE TABLE t1", {}, ()), ("INSERT INTO t1 (c1, c2)", {'c2': 'some data', 'c1': 5}, [5, 'some data']), @@ -158,10 +158,14 @@ class ProxyConnectionTest(TestBase): ("DROP TABLE t1", {}, ()) ] else: + insert2_params = [6, 'Foo'] + if testing.against('oracle+zxjdbc'): + from sqlalchemy.dialects.oracle.zxjdbc import ReturningParam + insert2_params.append(ReturningParam(12)) cursor = [ ("CREATE TABLE t1", {}, ()), ("INSERT INTO t1 (c1, c2)", {'c2': 'some data', 'c1': 5}, [5, 'some data']), - ("INSERT INTO t1 (c1, c2)", {'c1': 6, "lower_2":"Foo"}, [6, "Foo"]), # bind param name 'lower_2' might be incorrect + ("INSERT INTO t1 (c1, c2)", {'c1': 6, "lower_2":"Foo"}, insert2_params), # bind param name 'lower_2' might be incorrect ("select * from t1", {}, ()), ("DROP TABLE t1", {}, ()) ] diff --git a/test/sql/test_query.py b/test/sql/test_query.py index 934bdadbee..3222ff6ef4 100644 --- a/test/sql/test_query.py +++ b/test/sql/test_query.py @@ -80,8 +80,7 @@ class QueryTest(TestBase): ret[c.key] = row[c] return ret - if (testing.against('firebird', 'postgresql', 'oracle', 'mssql') and - not testing.against('oracle+zxjdbc')): + if testing.against('firebird', 'postgresql', 'oracle', 'mssql'): test_engines = [ engines.testing_engine(options={'implicit_returning':False}), engines.testing_engine(options={'implicit_returning':True}), @@ -168,8 +167,7 @@ class QueryTest(TestBase): eq_(r.inserted_primary_key, [12, 1]) def test_autoclose_on_insert(self): - if (testing.against('firebird', 'postgresql', 'oracle', 'mssql') and - not testing.against('oracle+zxjdbc')): + if testing.against('firebird', 'postgresql', 'oracle', 'mssql'): test_engines = [ engines.testing_engine(options={'implicit_returning':False}), engines.testing_engine(options={'implicit_returning':True}), diff --git a/test/sql/test_returning.py b/test/sql/test_returning.py index 474e0b3692..02d906dd84 100644 --- a/test/sql/test_returning.py +++ b/test/sql/test_returning.py @@ -5,7 +5,7 @@ from sqlalchemy.test.schema import Table, Column from sqlalchemy.types import TypeDecorator class ReturningTest(TestBase, AssertsExecutionResults): - __unsupported_on__ = ('sqlite', 'mysql', 'maxdb', 'sybase', 'access', 'oracle+zxjdbc') + __unsupported_on__ = ('sqlite', 'mysql', 'maxdb', 'sybase', 'access') def setup(self): meta = MetaData(testing.db) @@ -61,6 +61,7 @@ class ReturningTest(TestBase, AssertsExecutionResults): assert row['lala'] == 6 @testing.fails_on('firebird', "fb/kintersbasdb can't handle the bind params") + @testing.fails_on('oracle+zxjdbc', "JDBC driver bug") @testing.exclude('firebird', '<', (2, 0), '2.0+ feature') @testing.exclude('postgresql', '<', (8, 2), '8.3+ feature') def test_anon_expressions(self): @@ -92,6 +93,8 @@ class ReturningTest(TestBase, AssertsExecutionResults): eq_(result.fetchall(), [(1,)]) + @testing.crashes('oracle+zxjdbc', 'Triggers a "No more data to read from socket" and ' + 'prevents table from being dropped') @testing.fails_on('postgresql', '') @testing.fails_on('oracle', '') def test_executemany(): @@ -109,7 +112,8 @@ class ReturningTest(TestBase, AssertsExecutionResults): test_executemany() result3 = table.insert().returning(table.c.id).execute({'persons': 4, 'full': False}) - eq_([dict(row) for row in result3], [{'id': 4}]) + next = testing.against('oracle+zxjdbc') and 2 or 4 + eq_([dict(row) for row in result3], [{'id': next}]) @testing.exclude('firebird', '<', (2, 1), '2.1+ feature') @@ -137,7 +141,7 @@ class ReturningTest(TestBase, AssertsExecutionResults): eq_(result2.fetchall(), [(2,False),]) class SequenceReturningTest(TestBase): - __unsupported_on__ = ('sqlite', 'mysql', 'maxdb', 'sybase', 'access', 'mssql', 'oracle+zxjdbc') + __unsupported_on__ = ('sqlite', 'mysql', 'maxdb', 'sybase', 'access', 'mssql') def setup(self): meta = MetaData(testing.db) @@ -160,7 +164,7 @@ class SequenceReturningTest(TestBase): class KeyReturningTest(TestBase, AssertsExecutionResults): """test returning() works with columns that define 'key'.""" - __unsupported_on__ = ('sqlite', 'mysql', 'maxdb', 'sybase', 'access', 'oracle+zxjdbc') + __unsupported_on__ = ('sqlite', 'mysql', 'maxdb', 'sybase', 'access') def setup(self): meta = MetaData(testing.db) -- 2.47.2