]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- Refinements for maxdb's handling of SERIAL and FIXED columns
authorJason Kirtland <jek@discorporate.us>
Wed, 31 Oct 2007 00:28:53 +0000 (00:28 +0000)
committerJason Kirtland <jek@discorporate.us>
Wed, 31 Oct 2007 00:28:53 +0000 (00:28 +0000)
- Expanded maxdb's set of paren-less functions

lib/sqlalchemy/databases/maxdb.py
test/dialect/maxdb.py
test/sql/testtypes.py

index fcf04bec909d81be713a2cdb5c9763acf7d787c3..de72295b0bcdb34b5e9816567f43eba86698c0aa 100644 (file)
@@ -157,13 +157,21 @@ class MaxSmallInteger(MaxInteger):
 
 
 class MaxNumeric(sqltypes.Numeric):
-    """The NUMERIC (also FIXED, DECIMAL) data type."""
+    """The FIXED (also NUMERIC, DECIMAL) data type."""
+
+    def __init__(self, precision=None, length=None, **kw):
+        kw.setdefault('asdecimal', True)
+        super(MaxNumeric, self).__init__(length=length, precision=precision,
+                                         **kw)
+
+    def bind_processor(self, dialect):
+        return None
 
     def get_col_spec(self):
         if self.length and self.precision:
-            return 'NUMERIC(%s, %s)' % (self.precision, self.length)
-        elif self.length:
-            return 'NUMERIC(%s)' % self.length
+            return 'FIXED(%s, %s)' % (self.precision, self.length)
+        elif self.precision:
+            return 'FIXED(%s)' % self.precision
         else:
             return 'INTEGER'
 
@@ -344,20 +352,21 @@ colspecs = {
 
 ischema_names = { 
     'boolean': MaxBoolean,
-    'int': MaxInteger,
-    'integer': MaxInteger,
-    'varchar': MaxString,
     'char': MaxChar,
     'character': MaxChar,
+    'date': MaxDate,
     'fixed': MaxNumeric,
     'float': MaxFloat,
-    'long': MaxText,
+    'int': MaxInteger,
+    'integer': MaxInteger,
     'long binary': MaxBlob,
     'long unicode': MaxText,
     'long': MaxText,
+    'long': MaxText,
+    'smallint': MaxSmallInteger,
+    'time': MaxTime,
     'timestamp': MaxTimestamp,
-    'date': MaxDate,
-    'time': MaxTime
+    'varchar': MaxString,
     }
 
 
@@ -586,7 +595,7 @@ class MaxDBDialect(default.DefaultDialect):
         include_columns = util.Set(include_columns or [])
         
         for row in rows:
-            (name, mode, col_type, encoding, length, precision,
+            (name, mode, col_type, encoding, length, scale,
              nullable, constant_def, func_def) = row
 
             name = normalize(name)
@@ -596,7 +605,12 @@ class MaxDBDialect(default.DefaultDialect):
 
             type_args, type_kw = [], {}
             if col_type == 'FIXED':
-                type_args = length, precision
+                type_args = length, scale
+                # Convert FIXED(10) DEFAULT SERIAL to our Integer
+                if (scale == 0 and
+                    func_def is not None and func_def.startswith('SERIAL')):
+                    col_type = 'INTEGER'
+                    type_args = length,
             elif col_type in 'FLOAT':
                 type_args = length,
             elif col_type in ('CHAR', 'VARCHAR'):
@@ -620,10 +634,15 @@ class MaxDBDialect(default.DefaultDialect):
 
             if func_def is not None:
                 if func_def.startswith('SERIAL'):
-                    # strip current numbering
-                    col_kw['default'] = schema.PassiveDefault(
-                        sql.text('SERIAL'))
-                    col_kw['autoincrement'] = True
+                    if col_kw['primary_key']:
+                        # No special default- let the standard autoincrement
+                        # support handle SERIAL pk columns.
+                        col_kw['autoincrement'] = True
+                    else:
+                        # strip current numbering
+                        col_kw['default'] = schema.PassiveDefault(
+                            sql.text('SERIAL'))
+                        col_kw['autoincrement'] = True
                 else:
                     col_kw['default'] = schema.PassiveDefault(
                         sql.text(func_def))
@@ -705,8 +724,9 @@ class MaxDBCompiler(compiler.DefaultCompiler):
     # These functions must be written without parens when called with no
     # parameters.  e.g. 'SELECT DATE FROM DUAL' not 'SELECT DATE() FROM DUAL'
     bare_functions = util.Set([
-        'CURRENT_SCHEMA', 'DATE', 'TIME', 'TIMESTAMP', 'TIMEZONE',
-        'TRANSACTION', 'USER', 'UID', 'USERGROUP', 'UTCDATE'])
+        'CURRENT_SCHEMA', 'DATE', 'FALSE', 'SYSDBA', 'TIME', 'TIMESTAMP',
+        'TIMEZONE', 'TRANSACTION', 'TRUE', 'USER', 'UID', 'USERGROUP',
+        'UTCDATE', 'UTCDIFF'])
     
     def default_from(self):
         return ' FROM DUAL'
index 551c26b3746db235fa0e176f307773d51edde5a7..336986744853f535e0e57ba904c7a67db03acf7d 100644 (file)
@@ -3,7 +3,8 @@
 import testbase
 import StringIO, sys
 from sqlalchemy import *
-from sqlalchemy import sql
+from sqlalchemy import exceptions, sql
+from sqlalchemy.util import Decimal
 from sqlalchemy.databases import maxdb
 from testlib import *
 
@@ -12,12 +13,166 @@ from testlib import *
 # - add "Database" test, a quick check for join behavior on different max versions
 # - full max-specific reflection suite
 # - datetime tests
-# - decimal etc. tests
 # - the orm/query 'test_has' destabilizes the server- cover here
 
-class BasicTest(AssertMixin):
-    def test_import(self):
-        return True
+class ReflectionTest(AssertMixin):
+    """Extra reflection tests."""
+
+    def _test_decimal(self, tabledef):
+        """Checks a variety of FIXED usages.
+
+        This is primarily for SERIAL columns, which can be FIXED (scale-less)
+        or (SMALL)INT.  Ensures that FIXED id columns are converted to
+        integers and that are assignable as such.  Also exercises general
+        decimal assignment and selection behavior.
+        """
+
+        meta = MetaData(testbase.db)
+        try:
+            if isinstance(tabledef, basestring):
+                # run textual CREATE TABLE
+                testbase.db.execute(tabledef)
+            else:
+                _t = tabledef.tometadata(meta)
+                _t.create()
+            t = Table('dectest', meta, autoload=True)
+
+            vals = [Decimal('2.2'), Decimal('23'), Decimal('2.4'), 25]
+            cols = ['d1','d2','n1','i1']
+            t.insert().execute(dict(zip(cols,vals)))
+            roundtrip = list(t.select().execute())
+            self.assertEquals(roundtrip, [tuple([1] + vals)])
+
+            t.insert().execute(dict(zip(['id'] + cols,
+                                        [2] + list(roundtrip[0][1:]))))
+            roundtrip2 = list(t.select(order_by=t.c.id).execute())
+            self.assertEquals(roundtrip2, [tuple([1] + vals),
+                                           tuple([2] + vals)])
+        finally:
+            try:
+                testbase.db.execute("DROP TABLE dectest")
+            except exceptions.DatabaseError:
+                pass
+
+    @testing.supported('maxdb')
+    def test_decimal_fixed_serial(self):
+        tabledef = """
+        CREATE TABLE dectest (
+          id FIXED(10) DEFAULT SERIAL PRIMARY KEY,
+          d1 FIXED(10,2),
+          d2 FIXED(12),
+          n1 NUMERIC(12,2),
+          i1 INTEGER)
+          """
+        return self._test_decimal(tabledef)
+
+    @testing.supported('maxdb')
+    def test_decimal_integer_serial(self):
+        tabledef = """
+        CREATE TABLE dectest (
+          id INTEGER DEFAULT SERIAL PRIMARY KEY,
+          d1 DECIMAL(10,2),
+          d2 DECIMAL(12),
+          n1 NUMERIC(12,2),
+          i1 INTEGER)
+          """
+        return self._test_decimal(tabledef)
+
+    @testing.supported('maxdb')
+    def test_decimal_implicit_serial(self):
+        tabledef = """
+        CREATE TABLE dectest (
+          id SERIAL PRIMARY KEY,
+          d1 FIXED(10,2),
+          d2 FIXED(12),
+          n1 NUMERIC(12,2),
+          i1 INTEGER)
+          """
+        return self._test_decimal(tabledef)
+
+    @testing.supported('maxdb')
+    def test_decimal_smallint_serial(self):
+        tabledef = """
+        CREATE TABLE dectest (
+          id SMALLINT DEFAULT SERIAL PRIMARY KEY,
+          d1 FIXED(10,2),
+          d2 FIXED(12),
+          n1 NUMERIC(12,2),
+          i1 INTEGER)
+          """
+        return self._test_decimal(tabledef)
+
+    @testing.supported('maxdb')
+    def test_decimal_sa_types_1(self):
+        tabledef = Table('dectest', MetaData(),
+                         Column('id', Integer, primary_key=True),
+                         Column('d1', DECIMAL(10, 2)),
+                         Column('d2', DECIMAL(12)),
+                         Column('n1', NUMERIC(12,2)),
+                         Column('i1', Integer))
+        return self._test_decimal(tabledef)
+
+    @testing.supported('maxdb')
+    def test_decimal_sa_types_2(self):
+        tabledef = Table('dectest', MetaData(),
+                         Column('id', Integer, primary_key=True),
+                         Column('d1', maxdb.MaxNumeric(10, 2)),
+                         Column('d2', maxdb.MaxNumeric(12)),
+                         Column('n1', maxdb.MaxNumeric(12,2)),
+                         Column('i1', Integer))
+        return self._test_decimal(tabledef)
+
+    @testing.supported('maxdb')
+    def test_decimal_sa_types_3(self):
+        tabledef = Table('dectest', MetaData(),
+                         Column('id', Integer, primary_key=True),
+                         Column('d1', maxdb.MaxNumeric(10, 2)),
+                         Column('d2', maxdb.MaxNumeric),
+                         Column('n1', maxdb.MaxNumeric(12,2)),
+                         Column('i1', Integer))
+        return self._test_decimal(tabledef)
+
+    @testing.supported('maxdb')
+    def test_assorted_type_aliases(self):
+        """Ensures that aliased types are reflected properly."""
+
+        meta = MetaData(testbase.db)
+        try:
+            testbase.db.execute("""
+            CREATE TABLE assorted (
+              c1 INT,
+              c2 BINARY(2),
+              c3 DEC(4,2),
+              c4 DEC(4),
+              c5 DEC,
+              c6 DOUBLE PRECISION,
+              c7 NUMERIC(4,2),
+              c8 NUMERIC(4),
+              c9 NUMERIC,
+              c10 REAL(4),
+              c11 REAL,
+              c12 CHARACTER(2))
+              """)
+            table = Table('assorted', meta, autoload=True)
+            expected = [maxdb.MaxInteger,
+                        maxdb.MaxNumeric,
+                        maxdb.MaxNumeric,
+                        maxdb.MaxNumeric,
+                        maxdb.MaxNumeric,
+                        maxdb.MaxFloat,
+                        maxdb.MaxNumeric,
+                        maxdb.MaxNumeric,
+                        maxdb.MaxNumeric,
+                        maxdb.MaxFloat,
+                        maxdb.MaxFloat,
+                        maxdb.MaxChar,]
+            for i, col in enumerate(table.columns):
+                self.assert_(isinstance(col.type, expected[i]))
+        finally:
+            try:
+                testbase.db.execute("DROP TABLE assorted")
+            except exceptions.DatabaseError:
+                pass
 
 class DBAPITest(AssertMixin):
     """Asserts quirks in the native Python DB-API driver.
index 101efb79ba8a789d3493631951a494b227f9cd77..4af96d57fe96e5f520866051697cc15768a425d7 100644 (file)
@@ -220,12 +220,16 @@ class ColumnsTest(AssertMixin):
                             'smallint_column': 'smallint_column SMALLINT',
                             'varchar_column': 'varchar_column VARCHAR(20)',
                             'numeric_column': 'numeric_column NUMERIC(12, 3)',
-                            'float_column': 'float_column NUMERIC(25, 2)'
+                            'float_column': 'float_column FLOAT(25)',
                           }
 
         db = testbase.db
-        if not db.name=='sqlite' and not db.name=='oracle':
-            expectedResults['float_column'] = 'float_column FLOAT(25)'
+        if testing.against('sqlite', 'oracle'):
+            expectedResults['float_column'] = 'float_column NUMERIC(25, 2)'
+
+        if testing.against('maxdb'):
+            expectedResults['numeric_column'] = (
+                expectedResults['numeric_column'].replace('NUMERIC', 'FIXED'))
 
         print db.engine.__module__
         testTable = Table('testColumns', MetaData(db),
@@ -237,7 +241,10 @@ class ColumnsTest(AssertMixin):
         )
 
         for aCol in testTable.c:
-            self.assertEquals(expectedResults[aCol.name], db.dialect.schemagenerator(db.dialect, db, None, None).get_column_specification(aCol))
+            self.assertEquals(
+                expectedResults[aCol.name],
+                db.dialect.schemagenerator(db.dialect, db, None, None).\
+                  get_column_specification(aCol))
 
 class UnicodeTest(AssertMixin):
     """tests the Unicode type.  also tests the TypeDecorator with instances in the types package."""