]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- the Oracle dialect now features NUMBER which intends
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 9 Aug 2009 23:46:06 +0000 (23:46 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 9 Aug 2009 23:46:06 +0000 (23:46 +0000)
to act justlike Oracle's NUMBER type.  It is the primary
numeric type returned by table reflection and attempts
to return Decimal()/float/int based on the precision/scale
parameters.  [ticket:885]

CHANGES
lib/sqlalchemy/dialects/oracle/base.py
lib/sqlalchemy/dialects/oracle/cx_oracle.py
lib/sqlalchemy/sql/compiler.py
test/dialect/test_oracle.py
test/sql/test_types.py

diff --git a/CHANGES b/CHANGES
index cc005a27122a98f3643821532fe166f0393cba28..27a38145e6162a82bb248321bc847755c20717be 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -253,6 +253,12 @@ CHANGES
 
     - an NCLOB type is added to the base types.
 
+    - the Oracle dialect now features NUMBER which intends
+      to act justlike Oracle's NUMBER type.  It is the primary
+      numeric type returned by table reflection and attempts
+      to return Decimal()/float/int based on the precision/scale
+      parameters.  [ticket:885]
+      
     - func.char_length is a generic function for LENGTH
 
     - ForeignKey() which includes onupdate=<value> will emit a
index 17b09e79cc9c491214bc1d578dd5f9d83cc74233..a5ced0738ac9d81ebd40e89afcfd648bd8b5dc05 100644 (file)
@@ -132,14 +132,26 @@ class NCLOB(sqltypes.Text):
 VARCHAR2 = VARCHAR
 NVARCHAR2 = NVARCHAR
 
-class NUMBER(sqltypes.Numeric):
+
+class NUMBER(sqltypes.Numeric, sqltypes.Integer):
     __visit_name__ = 'NUMBER'
     
-class BFILE(sqltypes.Binary):
-    __visit_name__ = 'BFILE'
-
+    def __init__(self, precision=None, scale=None, asdecimal=None):
+        if asdecimal is None:
+            asdecimal = bool(scale and scale > 0)
+                
+        super(NUMBER, self).__init__(precision=precision, scale=scale, asdecimal=asdecimal)
+            
 class DOUBLE_PRECISION(sqltypes.Numeric):
     __visit_name__ = 'DOUBLE_PRECISION'
+    def __init__(self, precision=None, scale=None, asdecimal=None):
+        if asdecimal is None:
+            asdecimal = False
+                
+        super(DOUBLE_PRECISION, self).__init__(precision=precision, scale=scale, asdecimal=asdecimal)
+
+class BFILE(sqltypes.Binary):
+    __visit_name__ = 'BFILE'
 
 class LONG(sqltypes.Text):
     __visit_name__ = 'LONG'
@@ -200,13 +212,24 @@ class OracleTypeCompiler(compiler.GenericTypeCompiler):
         return self.visit_DATE(type_)
     
     def visit_float(self, type_):
-        if type_.precision is None:
-            return "NUMERIC"
-        else:
-            return "NUMERIC(%(precision)s, %(scale)s)" % {'precision': type_.precision, 'scale' : 2}
+        return self.visit_FLOAT(type_)
         
     def visit_unicode(self, type_):
         return self.visit_NVARCHAR(type_)
+    def visit_DOUBLE_PRECISION(self, type_):
+        return self._generate_numeric(type_, "DOUBLE PRECISION")
+        
+    def visit_NUMBER(self, type_):
+        return self._generate_numeric(type_, "NUMBER")
+    
+    def _generate_numeric(self, type_, name):
+        if type_.precision is None:
+            return name
+        elif type_.scale is None:
+            return "%(name)s(%(precision)s)" % {'name':name,'precision': type_.precision}
+        else:
+            return "%(name)s(%(precision)s, %(scale)s)" % {'name':name,'precision': type_.precision, 'scale' : type_.scale}
         
     def visit_VARCHAR(self, type_):
         return "VARCHAR(%(length)s)" % {'length' : type_.length}
@@ -658,18 +681,8 @@ class OracleDialect(default.DefaultDialect):
             (colname, coltype, length, precision, scale, nullable, default) = \
                 (self.normalize_name(row[0]), row[1], row[2], row[3], row[4], row[5]=='Y', row[6])
 
-            # INTEGER if the scale is 0 and precision is null
-            # NUMBER if the scale and precision are both null
-            # NUMBER(9,2) if the precision is 9 and the scale is 2
-            # NUMBER(3) if the precision is 3 and scale is 0
-            #length is ignored except for CHAR and VARCHAR2
             if coltype == 'NUMBER' :
-                if precision is None and scale is None:
-                    coltype = sqltypes.NUMERIC
-                elif precision is None and scale == 0:
-                    coltype = sqltypes.INTEGER
-                else :
-                    coltype = sqltypes.NUMERIC(precision, scale)
+                coltype = NUMBER(precision, scale)
             elif coltype=='CHAR' or coltype=='VARCHAR2':
                 coltype = self.ischema_names.get(coltype)(length)
             else:
index 475d6559aa84172d84a673ba044f497ccd47894e..f4092359142f30d47d41735b90a9d3b2fc88c4e0 100644 (file)
@@ -178,6 +178,7 @@ colspecs = {
     sqltypes.TIMESTAMP : _OracleTimestamp,
     sqltypes.Integer : _OracleInteger,  # this is only needed for OUT parameters.
                                         # it would be nice if we could not use it otherwise.
+    oracle.NUMBER : oracle.NUMBER, # don't let this get converted
     oracle.RAW: _OracleRaw,
 }
 
index d6187bcde822b0c1c961eb5e2a0eb94c1cf5da27..403ec968bbcd18afe48218eb219ee3006df62762 100644 (file)
@@ -1149,6 +1149,8 @@ class GenericTypeCompiler(engine.TypeCompiler):
     def visit_NUMERIC(self, type_):
         if type_.precision is None:
             return "NUMERIC"
+        elif type_.scale is None:
+            return "NUMERIC(%(precision)s)" % {'precision': type_.precision}
         else:
             return "NUMERIC(%(precision)s, %(scale)s)" % {'precision': type_.precision, 'scale' : type_.scale}
 
index 85c3097be15cc574a7d61d165d7519d08425c142..f8cfdf1fcabc4e194a6afbc80564656509ecb4c0 100644 (file)
@@ -10,6 +10,7 @@ from sqlalchemy.test.engines import testing_engine
 from sqlalchemy.dialects.oracle import cx_oracle, base as oracle
 from sqlalchemy.engine import default
 from sqlalchemy.util import jython
+from decimal import Decimal
 import os
 
 
@@ -380,7 +381,57 @@ class TypesTest(TestBase, AssertsCompiledSQL):
             assert isinstance(x, int)
         finally:
             t1.drop()
-        
+    
+    def test_numerics(self):
+        m = MetaData(testing.db)
+        t1 = Table('t1', m, 
+            Column('intcol', Integer),
+            Column('numericcol', Numeric(precision=9, scale=2)),
+            Column('floatcol1', Float()),
+            Column('floatcol2', FLOAT()),
+            Column('doubleprec', oracle.DOUBLE_PRECISION),
+            Column('numbercol1', oracle.NUMBER(9)),
+            Column('numbercol2', oracle.NUMBER(9, 3)),
+            Column('numbercol3', oracle.NUMBER),
+            
+        )
+        t1.create()
+        try:
+            t1.insert().execute(
+                intcol=1, 
+                numericcol=5.2, 
+                floatcol1=6.5, 
+                floatcol2 = 8.5,
+                doubleprec = 9.5, 
+                numbercol1=12,
+                numbercol2=14.85,
+                numbercol3=15.76
+                )
+            
+            m2 = MetaData(testing.db)
+            t2 = Table('t1', m2, autoload=True)
+
+            for row in (
+                t1.select().execute().first(),
+                t2.select().execute().first() 
+            ):
+                for i, (val, type_) in enumerate((
+                    (1, int),
+                    (Decimal("5.2"), Decimal),
+                    (6.5, float),
+                    (8.5, float),
+                    (9.5, float),
+                    (12, int),
+                    (Decimal("14.85"), Decimal),
+                    (15.76, float),
+                )):
+                    eq_(row[i], val)
+                    assert isinstance(row[i], type_)
+
+        finally:
+            t1.drop()
+    
+    
     def test_reflect_raw(self):
         types_table = Table(
         'all_types', MetaData(testing.db),
index ccd6e50381320d8f6860c173f631c1ded23178f3..cede11cc5ffbb2782264dfa50a04a6c54bf0bb14 100644 (file)
@@ -205,7 +205,7 @@ class ColumnsTest(TestBase, AssertsExecutionResults):
 
         db = testing.db
         if testing.against('oracle'):
-            expectedResults['float_column'] = 'float_column NUMERIC(25, 2)'
+            expectedResults['float_column'] = 'float_column FLOAT'
 
         if testing.against('sqlite'):
             expectedResults['float_column'] = 'float_column FLOAT'