]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- very rudimental support for OUT parameters added; use sql.outparam(name, type)
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 26 Jul 2007 22:09:52 +0000 (22:09 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 26 Jul 2007 22:09:52 +0000 (22:09 +0000)
    to set up an OUT parameter, just like bindparam(); after execution, values are
    avaiable via result.out_parameters dictionary. [ticket:507]
- dialect.get_type_map() apparently never worked, not sure why unit test seemed
to work the first time around.
- OracleText doesn't seem to return cx_oracle.LOB.

CHANGES
lib/sqlalchemy/databases/oracle.py
lib/sqlalchemy/engine/base.py
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/sql.py
test/dialect/alltests.py
test/dialect/oracle.py [new file with mode: 0644]
test/sql/query.py
test/sql/testtypes.py

diff --git a/CHANGES b/CHANGES
index 01c245038ba99cfcb1bd906d411e3b921881ffe1..ec8d8fcce12317deb0406af26343444948c7f783 100644 (file)
--- a/CHANGES
+++ b/CHANGES
     from SelectResults isn't present anymore, need to use join(). 
 - postgres
   - Added PGArray datatype for using postgres array datatypes
+- oracle
+  - very rudimental support for OUT parameters added; use sql.outparam(name, type)
+    to set up an OUT parameter, just like bindparam(); after execution, values are
+    avaiable via result.out_parameters dictionary. [ticket:507]
 
 0.3.11
 - orm
index 38a7b50de81fa5f1177b82d7e3ac0adac4fb501a..cdf14c9fa739c6ffa3bfa2a02f4e0dadac7c4ec7 100644 (file)
@@ -85,11 +85,11 @@ class OracleText(sqltypes.TEXT):
     def get_col_spec(self):
         return "CLOB"
 
-    def convert_result_value(self, value, dialect):
-        if value is None:
-            return None
-        else:
-            return super(OracleText, self).convert_result_value(value.read(), dialect)
+   # def convert_result_value(self, value, dialect):
+   #     if value is None:
+   #         return None
+   #     else:
+   #         return super(OracleText, self).convert_result_value(value.read(), dialect)
 
 
 class OracleRaw(sqltypes.Binary):
@@ -178,8 +178,26 @@ class OracleExecutionContext(default.DefaultExecutionContext):
         super(OracleExecutionContext, self).pre_exec()
         if self.dialect.auto_setinputsizes:
             self.set_input_sizes()
+        if self.compiled_parameters is not None and not isinstance(self.compiled_parameters, list):
+            for key in self.compiled_parameters:
+                (bindparam, name, value) = self.compiled_parameters.get_parameter(key)
+                if bindparam.isoutparam:
+                    dbtype = bindparam.type.dialect_impl(self.dialect).get_dbapi_type(self.dialect.dbapi)
+                    if not hasattr(self, 'out_parameters'):
+                        self.out_parameters = {}
+                    self.out_parameters[name] = self.cursor.var(dbtype)
+                    self.parameters[name] = self.out_parameters[name]
 
     def get_result_proxy(self):
+        if hasattr(self, 'out_parameters'):
+            if self.compiled_parameters is not None:
+                 for k in self.out_parameters:
+                     type = self.compiled_parameters.get_type(k)
+                     self.out_parameters[k] = type.dialect_impl(self.dialect).convert_result_value(self.out_parameters[k].getvalue(), self.dialect)
+            else:
+                 for k in self.out_parameters:
+                     self.out_parameters[k] = self.out_parameters[k].getvalue()
+
         if self.cursor.description is not None:
             for column in self.cursor.description:
                 type_code = column[1]
index ad029425565606875abe43418f1273f429dde323..5a2de338953d6b241f015e4627ec5cff1f294169 100644 (file)
@@ -1102,6 +1102,7 @@ class ResultProxy(object):
             return self.context.get_rowcount()
     rowcount = property(_get_rowcount)
     lastrowid = property(lambda s:s.cursor.lastrowid)
+    out_parameters = property(lambda s:s.context.out_parameters)
     
     def _init_metadata(self):
         if hasattr(self, '_ResultProxy__props'):
index a87a2e01704dbbb4c17de1e991b36b0fe60f0c3f..03af9272acbe54be7790905c7370dfb2445095d4 100644 (file)
@@ -33,8 +33,9 @@ class DefaultDialect(base.Dialect):
         dialect_module = sys.modules[self.__class__.__module__]
         map = {}
         for obj in dialect_module.__dict__.values():
-            if isinstance(obj, types.TypeEngine):
-                map[obj().get_dbapi_type(self.dialect)] = obj
+            if isinstance(obj, type) and issubclass(obj, types.TypeEngine):
+                obj = obj()
+                map[obj.get_dbapi_type(self.dbapi)] = obj
         self._dbapi_type_map = map
     
     def decode_result_columnname(self, name):
index a40af7d6d0a1712f61fceead02241c25160b04d5..01588e92da1a98f32347fa2bcdff00a44f7974fd 100644 (file)
@@ -35,7 +35,7 @@ __all__ = ['Alias', 'ClauseElement', 'ClauseParameters',
            'between', 'bindparam', 'case', 'cast', 'column', 'delete',
            'desc', 'distinct', 'except_', 'except_all', 'exists', 'extract', 'func', 'modifier',
            'insert', 'intersect', 'intersect_all', 'join', 'literal',
-           'literal_column', 'not_', 'null', 'or_', 'outerjoin', 'select',
+           'literal_column', 'not_', 'null', 'or_', 'outparam', 'outerjoin', 'select',
            'subquery', 'table', 'text', 'union', 'union_all', 'update',]
 
 BIND_PARAMS = re.compile(r'(?<![:\w\x5c]):(\w+)(?!:)', re.UNICODE)
@@ -681,7 +681,7 @@ def outparam(key, type_=None):
     attribute, which returns a dictionary containing the values.
     """
     
-    return _BindParamClause(key, type_=type_, unique=False, isoutparam=True)
+    return _BindParamClause(key, None, type_=type_, unique=False, isoutparam=True)
     
 def text(text, bind=None, *args, **kwargs):
     """Create literal text to be inserted into a query.
@@ -808,6 +808,9 @@ class ClauseParameters(object):
         self.__binds = {}
         self.positional = positional or []
 
+    def get_parameter(self, key):
+        return self.__binds[key]
+
     def set_parameter(self, bindparam, value, name):
         self.__binds[name] = [bindparam, name, value]
         
@@ -823,6 +826,9 @@ class ClauseParameters(object):
    
     def keys(self):
         return self.__binds.keys()
+
+    def __iter__(self):
+        return iter(self.keys())
  
     def __getitem__(self, key):
         return self.get_processed(key)
index f4b39dd6f363edccdd5baa194f4bb79dad74bce6..8900736259cd9254250e3fffbe0ed4cb93a84f89 100644 (file)
@@ -5,6 +5,7 @@ def suite():
     modules_to_test = (
         'dialect.mysql',
         'dialect.postgres',
+        'dialect.oracle',
         )
     alltests = unittest.TestSuite()
     for name in modules_to_test:
diff --git a/test/dialect/oracle.py b/test/dialect/oracle.py
new file mode 100644 (file)
index 0000000..c45d482
--- /dev/null
@@ -0,0 +1,31 @@
+import testbase, testing
+from sqlalchemy import *
+from sqlalchemy.databases import mysql
+from testlib import *
+
+
+class OutParamTest(AssertMixin):
+    @testing.supported('oracle')
+    def setUpAll(self):
+        testbase.db.execute("""
+create or replace procedure foo(x_in IN number, x_out OUT number, y_out OUT number) IS
+  retval number;
+    begin
+    retval := 6;
+    x_out := 10;
+    y_out := x_in * 15;
+    end;
+        """)
+
+    @testing.supported('oracle')
+    def test_out_params(self):
+        result = testbase.db.execute(text("begin foo(:x, :y, :z); end;", bindparams=[bindparam('x', Numeric), outparam('y', Numeric), outparam('z', Numeric)]), x=5)
+        assert result.out_parameters == {'y':10, 'z':75}, result.out_parameters
+        print result.out_parameters
+
+    @testing.supported('oracle')
+    def tearDownAll(self):
+         testbase.db.execute("DROP PROCEDURE foo")
+
+if __name__ == '__main__':
+    testbase.main()
index bfad2c26715c27bfd6112023dbdb4b1d0712d8af..48a28a9a56ef6e071680d80fc5decc30d15e45cf 100644 (file)
@@ -58,14 +58,15 @@ class QueryTest(PersistTest):
             if result.lastrow_has_defaults():
                 criterion = and_(*[col==id for col, id in zip(table.primary_key, result.last_inserted_ids())])
                 row = table.select(criterion).execute().fetchone()
-                ret.update(row)
+                for c in table.c:
+                    ret[c.key] = row[c]
             return ret
 
         for supported, table, values, assertvalues in [
             (
                 {'unsupported':['sqlite']},
                 Table("t1", metadata, 
-                    Column('id', Integer, primary_key=True),
+                    Column('id', Integer, Sequence('t1_id_seq', optional=True), primary_key=True),
                     Column('foo', String(30), primary_key=True)),
                 {'foo':'hi'},
                 {'id':1, 'foo':'hi'}
@@ -73,7 +74,7 @@ class QueryTest(PersistTest):
             (
                 {'unsupported':['sqlite']},
                 Table("t2", metadata, 
-                    Column('id', Integer, primary_key=True),
+                    Column('id', Integer, Sequence('t2_id_seq', optional=True), primary_key=True),
                     Column('foo', String(30), primary_key=True),
                     Column('bar', String(30), PassiveDefault('hi'))
                 ),
@@ -93,7 +94,7 @@ class QueryTest(PersistTest):
             (
                 {'unsupported':[]},
                 Table("t4", metadata, 
-                    Column('id', Integer, primary_key=True),
+                    Column('id', Integer, Sequence('t4_id_seq', optional=True), primary_key=True),
                     Column('foo', String(30), primary_key=True),
                     Column('bar', String(30), PassiveDefault('hi'))
                 ),
index 6b2c98d79b1abb412a6b87246d56c6860103947f..6590330164eaa0c2f7c101106439bcf93757533c 100644 (file)
@@ -382,8 +382,13 @@ class NumericTest(AssertMixin):
         from decimal import Decimal
         numeric_table.insert().execute(numericcol=3.5, floatcol=5.6, ncasdec=12.4, fcasdec=15.78)
         numeric_table.insert().execute(numericcol=Decimal("3.5"), floatcol=Decimal("5.6"), ncasdec=Decimal("12.4"), fcasdec=Decimal("15.78"))
-        print numeric_table.select().execute().fetchall()
-        assert numeric_table.select().execute().fetchall() == [
+        l = numeric_table.select().execute().fetchall()
+        print l
+        rounded = [
+            (l[0][0], l[0][1], round(l[0][2], 5), l[0][3], l[0][4]),
+            (l[1][0], l[1][1], round(l[1][2], 5), l[1][3], l[1][4]),
+        ]
+        assert rounded == [
             (1, 3.5, 5.6, Decimal("12.4"), Decimal("15.78")),
             (2, 3.5, 5.6, Decimal("12.4"), Decimal("15.78")),
         ]