]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
fix up oracle tests, returning is on by default
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 8 Aug 2009 16:49:28 +0000 (16:49 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 8 Aug 2009 16:49:28 +0000 (16:49 +0000)
lib/sqlalchemy/dialects/oracle/base.py
lib/sqlalchemy/dialects/oracle/cx_oracle.py
lib/sqlalchemy/engine/base.py
test/dialect/test_oracle.py
test/sql/test_returning.py

index 9ba01610157ac35731d737932b2c68d90beb8a21..17b09e79cc9c491214bc1d578dd5f9d83cc74233 100644 (file)
@@ -318,7 +318,7 @@ class OracleCompiler(compiler.SQLCompiler):
         columnlist = list(expression._select_iterables(returning_cols))
         
         # within_columns_clause =False so that labels (foo AS bar) don't render
-        columns = [self.process(c, within_columns_clause=False) for c in columnlist]
+        columns = [self.process(c, within_columns_clause=False, result_map=self.result_map) for c in columnlist]
         
         binds = [create_out_param(c, i) for i, c in enumerate(columnlist)]
         
index d8a0c445a42e08ee1981313e02b358904a6c6945..475d6559aa84172d84a673ba044f497ccd47894e 100644 (file)
@@ -20,6 +20,9 @@ URL, or as keyword arguments to :func:`~sqlalchemy.create_engine()` are:
 
 * *allow_twophase* - enable two-phase transactions.  Defaults to ``True``.
 
+* *arraysize* - set the cx_oracle.arraysize value on cursors, in SQLAlchemy
+  it defaults to 50.  See the section on "LOB Objects" below.
+  
 * *auto_convert_lobs* - defaults to True, see the section on LOB objects.
 
 * *auto_setinputsizes* - the cx_oracle.setinputsizes() call is issued for all bind parameters.
@@ -32,6 +35,14 @@ URL, or as keyword arguments to :func:`~sqlalchemy.create_engine()` are:
 * *threaded* - enable multithreaded access to cx_oracle connections.  Defaults
   to ``True``.  Note that this is the opposite default of cx_oracle itself.
 
+Unicode
+-------
+
+As of cx_oracle 5, Python unicode objects can be bound directly to statements, 
+and it appears that cx_oracle can handle these even without NLS_LANG being set.
+SQLAlchemy tests for version 5 and will pass unicode objects straight to cx_oracle
+if this is the case.  For older versions of cx_oracle, SQLAlchemy will encode bind
+parameters normally using dialect.encoding as the encoding.
 
 LOB Objects
 -----------
@@ -47,7 +58,7 @@ The size of a "batch of rows" is controlled by the cursor.arraysize value, which
 defaults to 50 (cx_oracle normally defaults this to one).  
 
 Secondly, the LOB object is not a standard DBAPI return value so SQLAlchemy seeks to 
-"normalize" the results to look more like other DBAPIs.
+"normalize" the results to look more like that of other DBAPIs.
 
 The conversion of LOB objects by this dialect is unique in SQLAlchemy in that it takes place
 for all statement executions, even plain string-based statements for which SQLA has no awareness
@@ -58,8 +69,8 @@ of LOB objects, can be disabled using auto_convert_lobs=False.
 Two Phase Transaction Support
 -----------------------------
 
-Two Phase transactions are implemented using XA transactions.  Success has been reported of them
-working successfully but this should be regarded as an experimental feature.
+Two Phase transactions are implemented using XA transactions.  Success has been reported 
+with this feature but it should be regarded as experimental.
 
 """
 
@@ -137,7 +148,14 @@ class _OracleUnicodeText(_LOBMixin, sqltypes.UnicodeText):
     def get_dbapi_type(self, dbapi):
         return dbapi.NCLOB
 
-
+class _OracleInteger(sqltypes.Integer):
+    def result_processor(self, dialect):
+        def to_int(val):
+            if val is not None:
+                val = int(val)
+            return val
+        return to_int
+        
 class _OracleBinary(_LOBMixin, sqltypes.Binary):
     def get_dbapi_type(self, dbapi):
         return dbapi.BLOB
@@ -158,6 +176,8 @@ colspecs = {
     sqltypes.Text : _OracleText,
     sqltypes.UnicodeText : _OracleUnicodeText,
     sqltypes.TIMESTAMP : _OracleTimestamp,
+    sqltypes.Integer : _OracleInteger,  # this is only needed for OUT parameters.
+                                        # it would be nice if we could not use it otherwise.
     oracle.RAW: _OracleRaw,
 }
 
@@ -194,7 +214,6 @@ class Oracle_cx_oracleExecutionContext(DefaultExecutionContext):
                     self.out_parameters[name] = self.cursor.var(dbtype)
                     self.parameters[0][quoted_bind_names.get(name, name)] = self.out_parameters[name]
         
-        
     def create_cursor(self):
         c = self._connection.connection.cursor()
         if self.dialect.arraysize:
@@ -202,50 +221,57 @@ class Oracle_cx_oracleExecutionContext(DefaultExecutionContext):
         return c
 
     def get_result_proxy(self):
+        if hasattr(self, 'out_parameters') and self.compiled.returning:
+            returning_params = dict((k, v.getvalue()) for k, v in self.out_parameters.items())
+            return ReturningResultProxy(self, returning_params)
+
+        result = None
+        if self.cursor.description is not None:
+            for column in self.cursor.description:
+                type_code = column[1]
+                if type_code in self.dialect.ORACLE_BINARY_TYPES:
+                    result = base.BufferedColumnResultProxy(self)
+        
+        if result is None:
+            result = base.ResultProxy(self)
+            
         if hasattr(self, 'out_parameters'):
             if self.compiled_parameters is not None and len(self.compiled_parameters) == 1:
+                result.out_parameters = out_parameters = {}
+                
                 for bind, name in self.compiled.bind_names.iteritems():
                     if name in self.out_parameters:
                         type = bind.type
                         result_processor = type.dialect_impl(self.dialect).result_processor(self.dialect)
                         if result_processor is not None:
-                            self.out_parameters[name] = result_processor(self.out_parameters[name].getvalue())
+                            out_parameters[name] = result_processor(self.out_parameters[name].getvalue())
                         else:
-                            self.out_parameters[name] = self.out_parameters[name].getvalue()
+                            out_parameters[name] = self.out_parameters[name].getvalue()
             else:
-                for k in self.out_parameters:
-                    self.out_parameters[k] = self.out_parameters[k].getvalue()
+                result.out_parameters = dict((k, v.getvalue()) for k, v in self.out_parameters.items())
 
-        if self.cursor.description is not None:
-            for column in self.cursor.description:
-                type_code = column[1]
-                if type_code in self.dialect.ORACLE_BINARY_TYPES:
-                    return base.BufferedColumnResultProxy(self)
-        
-        if hasattr(self, 'out_parameters') and \
-            self.compiled.returning:
-                
-            return ReturningResultProxy(self)
-        else:
-            return base.ResultProxy(self)
+        return result
 
 class ReturningResultProxy(base.FullyBufferedResultProxy):
     """Result proxy which stuffs the _returning clause + outparams into the fetch."""
     
+    def __init__(self, context, returning_params):
+        self._returning_params = returning_params
+        super(ReturningResultProxy, self).__init__(context)
+        
     def _cursor_description(self):
         returning = self.context.compiled.returning
         
         ret = []
         for c in returning:
-            if hasattr(c, 'key'):
-                ret.append((c.key, c.type))
+            if hasattr(c, 'name'):
+                ret.append((c.name, c.type))
             else:
                 ret.append((c.anon_label, c.type))
         return ret
     
     def _buffer_rows(self):
-        returning = self.context.compiled.returning
-        return [tuple(self.context.out_parameters["ret_%d" % i] for i, c in enumerate(returning))]
+        return [tuple(self._returning_params["ret_%d" % i] for i, c in enumerate(self._returning_params))]
 
 class Oracle_cx_oracle(OracleDialect):
     execution_ctx_cls = Oracle_cx_oracleExecutionContext
index 0a0b0ff0ca4bb4651f1975810c591bb3d9ba8bd1..e126cec6818373b37578215bc99261e14304e080 100644 (file)
@@ -1591,6 +1591,7 @@ class ResultProxy(object):
     """
 
     _process_row = RowProxy
+    out_parameters = None
     
     def __init__(self, context):
         self.context = context
@@ -1639,10 +1640,6 @@ class ResultProxy(object):
         """
         return self.cursor.lastrowid
 
-    @property
-    def out_parameters(self):
-        return self.context.out_parameters
-    
     def _cursor_description(self):
         return self.cursor.description
             
index 53e0f9ec2f4c0da214edc155f9a1198006f051f3..444f24cf28350d536249b8413eed7a35ee524051 100644 (file)
@@ -30,8 +30,10 @@ create or replace procedure foo(x_in IN number, x_out OUT number, y_out OUT numb
         """)
 
     def test_out_params(self):
-        result = testing.db.execute(text("begin foo(:x_in, :x_out, :y_out, :z_out); end;", bindparams=[bindparam('x_in', Numeric), outparam('x_out', Numeric), outparam('y_out', Numeric), outparam('z_out', String)]), x_in=5)
+        result = testing.db.execute(text("begin foo(:x_in, :x_out, :y_out, :z_out); end;", 
+                bindparams=[bindparam('x_in', Numeric), outparam('x_out', Integer), outparam('y_out', Numeric), outparam('z_out', String)]), x_in=5)
         assert result.out_parameters == {'x_out':10, 'y_out':75, 'z_out':None}, result.out_parameters
+        assert isinstance(result.out_parameters['x_out'], int)
 
     @classmethod
     def teardown_class(cls):
@@ -362,7 +364,22 @@ class TypesTest(TestBase, AssertsCompiledSQL):
         ]:
             assert isinstance(start.dialect_impl(dialect), test), "wanted %r got %r" % (test, start.dialect_impl(dialect))
 
-
+    def test_int_not_float(self):
+        m = MetaData(testing.db)
+        t1 = Table('t1', m, Column('foo', Integer))
+        t1.create()
+        try:
+            r = t1.insert().values(foo=5).returning(t1.c.foo).execute()
+            x = r.scalar()
+            assert x == 5
+            assert isinstance(x, int)
+
+            x = t1.select().scalar()
+            assert x == 5
+            assert isinstance(x, int)
+        finally:
+            t1.drop()
+        
     def test_reflect_raw(self):
         types_table = Table(
         'all_types', MetaData(testing.db),
@@ -417,6 +434,8 @@ class TypesTest(TestBase, AssertsCompiledSQL):
             eq_(row['bindata'].read(), 'this is binary')
         finally:
             t.drop(engine)
+            
+            
 class BufferedColumnTest(TestBase, AssertsCompiledSQL):
     __only_on__ = 'oracle'
 
@@ -448,7 +467,7 @@ class BufferedColumnTest(TestBase, AssertsCompiledSQL):
     @testing.fails_on('+zxjdbc', 'FIXME: zxjdbc should support this')
     def test_fetch_single_arraysize(self):
         eng = testing_engine(options={'arraysize':1})
-        result = eng.execute(binary_table.select()).fetchall(),
+        result = eng.execute(binary_table.select()).fetchall()
         if jython:
             result = [(i, value.tostring()) for i, value in result]
         eq_(result, [(i, stream) for i in range(1, 11)])
index e076f3fe7c88cb045f9002910b8fe5089db80eab..1b69c55ffcd25cc7d67dd5968a37537cbe4c8496 100644 (file)
@@ -4,7 +4,6 @@ from sqlalchemy.test import *
 from sqlalchemy.test.schema import Table, Column
 from sqlalchemy.types import TypeDecorator
 
-        
 class ReturningTest(TestBase, AssertsExecutionResults):
     __unsupported_on__ = ('sqlite', 'mysql', 'maxdb', 'sybase', 'access')
 
@@ -35,7 +34,7 @@ class ReturningTest(TestBase, AssertsExecutionResults):
     
     def teardown(self):
         table.drop()
-
+    
     @testing.exclude('firebird', '<', (2, 0), '2.0+ feature')
     @testing.exclude('postgresql', '<', (8, 2), '8.3+ feature')
     def test_column_targeting(self):
@@ -157,3 +156,33 @@ class SequenceReturningTest(TestBase):
         r = table.insert().values(data='hi').returning(table.c.id).execute()
         assert r.first() == (1, )
         assert seq.execute() == 2
+
+class KeyReturningTest(TestBase, AssertsExecutionResults):
+    """test returning() works with columns that define 'key'."""
+    
+    __unsupported_on__ = ('sqlite', 'mysql', 'maxdb', 'sybase', 'access')
+
+    def setup(self):
+        meta = MetaData(testing.db)
+        global table
+
+        table = Table('tables', meta,
+            Column('id', Integer, primary_key=True, key='foo_id', test_needs_autoincrement=True),
+            Column('data', String(20)),
+        )
+        table.create(checkfirst=True)
+
+    def teardown(self):
+        table.drop()
+
+    @testing.exclude('firebird', '<', (2, 0), '2.0+ feature')
+    @testing.exclude('postgresql', '<', (8, 2), '8.3+ feature')
+    def test_insert(self):
+        result = table.insert().returning(table.c.foo_id).execute(data='somedata')
+        row = result.first()
+        assert row[table.c.foo_id] == row['id'] == 1
+        
+        result = table.select().execute().first()
+        assert row[table.c.foo_id] == row['id'] == 1
+        
+