]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
mssql type fixes....
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 25 Jan 2009 20:52:02 +0000 (20:52 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 25 Jan 2009 20:52:02 +0000 (20:52 +0000)
lib/sqlalchemy/dialects/mssql/base.py
lib/sqlalchemy/sql/compiler.py
test/dialect/mssql.py

index 1964b6ddc5dad954ed9d012f6632c87de5773077..c9c6fd729ec41a2f53f48c54f54038f47d8054ea 100644 (file)
@@ -258,11 +258,7 @@ class MSNumeric(sqltypes.Numeric):
 
     def bind_processor(self, dialect):
         def process(value):
-            if value is None:
-                # Not sure that this exception is needed
-                return value
-
-            elif isinstance(value, decimal.Decimal):
+            if isinstance(value, decimal.Decimal):
                 if value.adjusted() < 0:
                     result = "%s0.%s%s" % (
                             (value < 0 and '-' or ''),
@@ -309,7 +305,7 @@ class MSTinyInteger(sqltypes.Integer):
 # filter bind parameters into datetime objects (required by pyodbc,
 # not sure about other dialects).
 
-class MSDate(sqltypes.Date):
+class MSDate(sqltypes.DATE):
     def bind_processor(self, dialect):
         def process(value):
             if type(value) == datetime.date:
@@ -329,7 +325,7 @@ class MSDate(sqltypes.Date):
                 return value
         return process
     
-class MSTime(sqltypes.Time):
+class MSTime(sqltypes.TIME):
     def __init__(self, precision=None, **kwargs):
         self.precision = precision
         super(MSTime, self).__init__()
@@ -356,7 +352,7 @@ class MSTime(sqltypes.Time):
                 return value
         return process
 
-class MSDateTime(sqltypes.DateTime):
+class _DateTimeBase(object):
     def bind_processor(self, dialect):
         def process(value):
             if type(value) == datetime.date:
@@ -364,11 +360,14 @@ class MSDateTime(sqltypes.DateTime):
             else:
                 return value
         return process
+
+class MSDateTime(_DateTimeBase, sqltypes.DATETIME):
+    pass
     
-class MSSmallDateTime(MSDateTime):
+class MSSmallDateTime(_DateTimeBase, sqltypes.DateTime):
     __visit_name__ = 'SMALLDATETIME'
 
-class MSDateTime2(MSDateTime):
+class MSDateTime2(_DateTimeBase, sqltypes.DateTime):
     __visit_name__ = 'DATETIME2'
     
     def __init__(self, precision=None, **kwargs):
@@ -549,11 +548,12 @@ class MSNChar(_StringType, sqltypes.NCHAR):
         sqltypes.NCHAR.__init__(self, *args, **kw)
 
 class MSBinary(sqltypes.Binary):
-    pass
+    __visit_name__ = 'BINARY'
 
 class MSVarBinary(sqltypes.Binary):
     __visit_name__ = 'VARBINARY'
 
+        
 class MSImage(sqltypes.Binary):
     __visit_name__ = 'IMAGE'
 
@@ -667,6 +667,8 @@ class MSTypeCompiler(compiler.GenericTypeCompiler):
 
     def visit_date(self, type_):
         if self.dialect.server_version_info < MS_2008_VERSION:
+            import pdb
+            pdb.set_trace()
             return self.visit_DATETIME(type_)
         else:
             return self.visit_DATE(type_)
index 7305f497ef7d227e1436a62eb02a86365fa0f420..216c0a76b23a18e73c82f90f6db752c2742b12bd 100644 (file)
@@ -1011,6 +1011,9 @@ class GenericTypeCompiler(engine.TypeCompiler):
         
     def visit_date(self, type_): 
         return self.visit_DATE(type_)
+
+    def visit_big_integer(self, type_): 
+        return self.visit_BIGINT(type_)
         
     def visit_small_integer(self, type_): 
         return self.visit_SMALLINT(type_)
index bebda1752e1c556be986b5f420ebf577d92dcd71..eec5518db229c6a4d8ea071dee0d9092f0668bfe 100755 (executable)
@@ -474,11 +474,11 @@ class ParseConnectTest(TestBase, AssertsCompiledSQL):
         self.assertEquals([['DRIVER={SQL Server};Server=hostspec;Database=database;UID=username;PWD=password'], {}], connection)
 
 
-class TypesTest(TestBase):
+class TypesTest(TestBase, AssertsExecutionResults):
     __only_on__ = 'mssql'
 
     def setUpAll(self):
-        global numeric_table, metadata
+        global metadata
         metadata = MetaData(testing.db)
 
     def tearDown(self):
@@ -534,11 +534,6 @@ class TypesTest(TestBase):
             raise e
 
 
-class TypesTest2(TestBase, AssertsExecutionResults):
-    "Test Microsoft SQL Server column types"
-
-    __only_on__ = 'mssql'
-
     def test_money(self):
         "Exercise type specification for money types."
 
@@ -550,7 +545,7 @@ class TypesTest2(TestBase, AssertsExecutionResults):
              'SMALLMONEY'),
            ]
 
-        table_args = ['test_mssql_money', MetaData(testing.db)]
+        table_args = ['test_mssql_money', metadata]
         for index, spec in enumerate(columns):
             type_, args, kw, res = spec
             table_args.append(Column('c%s' % index, type_(*args, **kw), nullable=None))
@@ -611,15 +606,14 @@ class TypesTest2(TestBase, AssertsExecutionResults):
 
             ]
 
-        table_args = ['test_mssql_dates', MetaData(testing.db)]
+        table_args = ['test_mssql_dates', metadata]
         for index, spec in enumerate(columns):
             type_, args, kw, res, requires = spec[0:5]
             if (requires and testing._is_excluded('mssql', *requires)) or not requires:
                 table_args.append(Column('c%s' % index, type_(*args, **kw), nullable=None))
 
         dates_table = Table(*table_args)
-        dialect = mssql.dialect()
-        gen = dialect.ddl_compiler(dialect, schema.CreateTable(dates_table))
+        gen = testing.db.dialect.ddl_compiler(testing.db.dialect, schema.CreateTable(dates_table))
 
         for col in dates_table.c:
             index = int(col.name[1:])
@@ -627,62 +621,51 @@ class TypesTest2(TestBase, AssertsExecutionResults):
                            "%s %s" % (col.name, columns[index][3]))
             self.assert_(repr(col))
 
-        try:
-            dates_table.create(checkfirst=True)
-            assert True
-        except:
-            raise
+        dates_table.create(checkfirst=True)
 
         reflected_dates = Table('test_mssql_dates', MetaData(testing.db), autoload=True)
         for col in reflected_dates.c:
             index = int(col.name[1:])
-            testing.eq_(testing.db.dialect.type_descriptor(col.type).__class__,
-                len(columns[index]) > 5 and columns[index][5] or columns[index][0])
-        dates_table.drop()
+            c1 = testing.db.dialect.type_descriptor(col.type).__class__
+            c2 = len(columns[index]) > 5 and columns[index][5] or columns[index][0]
+            assert issubclass(c1, c2), "%r is not a subclass of %r" % (c1, c2)
 
-    def test_dates2(self):
-        meta = MetaData(testing.db)
-        t = Table('test_dates', meta,
+    def test_date_roundtrip(self):
+        t = Table('test_dates', metadata,
                   Column('id', Integer,
                          Sequence('datetest_id_seq', optional=True),
                          primary_key=True),
                   Column('adate', Date),
                   Column('atime', Time),
                   Column('adatetime', DateTime))
-        t.create(checkfirst=True)
-        try:
-            d1 = datetime.date(2007, 10, 30)
-            t1 = datetime.time(11, 2, 32)
-            d2 = datetime.datetime(2007, 10, 30, 11, 2, 32)
-            t.insert().execute(adate=d1, adatetime=d2, atime=t1)
-            t.insert().execute(adate=d2, adatetime=d2, atime=d2)
-
-            x = t.select().execute().fetchall()[0]
-            self.assert_(x.adate.__class__ == datetime.date)
-            self.assert_(x.atime.__class__ == datetime.time)
-            self.assert_(x.adatetime.__class__ == datetime.datetime)
+        metadata.create_all()
+        d1 = datetime.date(2007, 10, 30)
+        t1 = datetime.time(11, 2, 32)
+        d2 = datetime.datetime(2007, 10, 30, 11, 2, 32)
+        t.insert().execute(adate=d1, adatetime=d2, atime=t1)
+        t.insert().execute(adate=d2, adatetime=d2, atime=d2)
 
-            t.delete().execute()
+        x = t.select().execute().fetchall()[0]
+        self.assert_(x.adate.__class__ == datetime.date)
+        self.assert_(x.atime.__class__ == datetime.time)
+        self.assert_(x.adatetime.__class__ == datetime.datetime)
 
-            t.insert().execute(adate=d1, adatetime=d2, atime=t1)
+        t.delete().execute()
 
-            self.assertEquals(select([t.c.adate, t.c.atime, t.c.adatetime], t.c.adate==d1).execute().fetchall(), [(d1, t1, d2)])
+        t.insert().execute(adate=d1, adatetime=d2, atime=t1)
 
-        finally:
-            t.drop(checkfirst=True)
+        self.assertEquals(select([t.c.adate, t.c.atime, t.c.adatetime], t.c.adate==d1).execute().fetchall(), [(d1, t1, d2)])
 
     def test_binary(self):
         "Exercise type specification for binary types."
 
         columns = [
             # column type, args, kwargs, expected ddl
-            (types.Binary, [], {},
+            (mssql.MSBinary, [], {},
              'BINARY'),
             (types.Binary, [10], {},
              'BINARY(10)'),
 
-            (mssql.MSBinary, [], {},
-             'IMAGE'),
             (mssql.MSBinary, [10], {},
              'BINARY(10)'),
 
@@ -700,7 +683,7 @@ class TypesTest2(TestBase, AssertsExecutionResults):
              'BINARY(10)')
             ]
 
-        table_args = ['test_mssql_binary', MetaData(testing.db)]
+        table_args = ['test_mssql_binary', metadata]
         for index, spec in enumerate(columns):
             type_, args, kw, res = spec
             table_args.append(Column('c%s' % index, type_(*args, **kw), nullable=None))
@@ -715,18 +698,15 @@ class TypesTest2(TestBase, AssertsExecutionResults):
                            "%s %s" % (col.name, columns[index][3]))
             self.assert_(repr(col))
 
-        binary_table.create(checkfirst=True)
+        metadata.create_all()
 
         reflected_binary = Table('test_mssql_binary', MetaData(testing.db), autoload=True)
         for col in reflected_binary.c:
-            # don't test the MSGenericBinary since it's a special case and
-            # reflected it will map to a MSImage or MSBinary depending
-            if not testing.db.dialect.type_descriptor(binary_table.c[col.name].type).__class__ == mssql.MSGenericBinary:
-                testing.eq_(testing.db.dialect.type_descriptor(col.type).__class__,
-                    testing.db.dialect.type_descriptor(binary_table.c[col.name].type).__class__)
+            c1 =testing.db.dialect.type_descriptor(col.type).__class__
+            c2 =testing.db.dialect.type_descriptor(binary_table.c[col.name].type).__class__ 
+            assert issubclass(c1, c2), "%r is not a subclass of %r" % (c1, c2)
             if binary_table.c[col.name].type.length:
                 testing.eq_(col.type.length, binary_table.c[col.name].type.length)
-        binary_table.drop()
 
     def test_boolean(self):
         "Exercise type specification for boolean type."
@@ -737,7 +717,7 @@ class TypesTest2(TestBase, AssertsExecutionResults):
              'BIT'),
            ]
 
-        table_args = ['test_mssql_boolean', MetaData(testing.db)]
+        table_args = ['test_mssql_boolean', metadata]
         for index, spec in enumerate(columns):
             type_, args, kw, res = spec
             table_args.append(Column('c%s' % index, type_(*args, **kw), nullable=None))
@@ -752,12 +732,7 @@ class TypesTest2(TestBase, AssertsExecutionResults):
                            "%s %s" % (col.name, columns[index][3]))
             self.assert_(repr(col))
 
-        try:
-            boolean_table.create(checkfirst=True)
-            assert True
-        except:
-            raise
-        boolean_table.drop()
+        metadata.create_all()
 
     def test_numeric(self):
         "Exercise type specification and options for numeric types."
@@ -792,7 +767,7 @@ class TypesTest2(TestBase, AssertsExecutionResults):
              'SMALLINT'),
            ]
 
-        table_args = ['test_mssql_numeric', MetaData(testing.db)]
+        table_args = ['test_mssql_numeric', metadata]
         for index, spec in enumerate(columns):
             type_, args, kw, res = spec
             table_args.append(Column('c%s' % index, type_(*args, **kw), nullable=None))
@@ -807,20 +782,11 @@ class TypesTest2(TestBase, AssertsExecutionResults):
                            "%s %s" % (col.name, columns[index][3]))
             self.assert_(repr(col))
 
-        try:
-            numeric_table.create(checkfirst=True)
-            assert True
-        except:
-            raise
-        numeric_table.drop()
+        metadata.create_all()
 
     def test_char(self):
         """Exercise COLLATE-ish options on string types."""
 
-        # modify the text_as_varchar setting since we are not testing that behavior here
-        text_as_varchar = testing.db.dialect.text_as_varchar
-        testing.db.dialect.text_as_varchar = False
-
         columns = [
             (mssql.MSChar, [], {},
              'CHAR'),
@@ -861,7 +827,7 @@ class TypesTest2(TestBase, AssertsExecutionResults):
              'NTEXT COLLATE Latin1_General_CI_AS'),
            ]
 
-        table_args = ['test_mssql_charset', MetaData(testing.db)]
+        table_args = ['test_mssql_charset', metadata]
         for index, spec in enumerate(columns):
             type_, args, kw, res = spec
             table_args.append(Column('c%s' % index, type_(*args, **kw), nullable=None))
@@ -876,104 +842,79 @@ class TypesTest2(TestBase, AssertsExecutionResults):
                            "%s %s" % (col.name, columns[index][3]))
             self.assert_(repr(col))
 
-        try:
-            charset_table.create(checkfirst=True)
-            assert True
-        except:
-            raise
-        charset_table.drop()
-
-        testing.db.dialect.text_as_varchar = text_as_varchar
+        metadata.create_all()
 
     def test_timestamp(self):
         """Exercise TIMESTAMP column."""
 
-        meta = MetaData(testing.db)
         dialect = mssql.dialect()
 
-        try:
-            columns = [
-                (TIMESTAMP,
-                 'TIMESTAMP'),
-                ]
-            for idx, (spec, expected) in enumerate(columns):
-                t = Table('mssql_ts%s' % idx, meta,
-                          Column('id', Integer, primary_key=True),
-                          Column('t', spec, nullable=None))
-                gen = dialect.ddl_compiler(dialect, schema.CreateTable(t))
-                testing.eq_(gen.get_column_specification(t.c.t), "t %s" % expected)
-                self.assert_(repr(t.c.t))
-                try:
-                    t.create(checkfirst=True)
-                    assert True
-                except:
-                    raise
-                t.drop()
-        finally:
-            meta.drop_all()
+        spec, expected = (TIMESTAMP,'TIMESTAMP')
+        t = Table('mssql_ts', metadata,
+                  Column('id', Integer, primary_key=True),
+                  Column('t', spec, nullable=None))
+        gen = dialect.ddl_compiler(dialect, schema.CreateTable(t))
+        testing.eq_(gen.get_column_specification(t.c.t), "t %s" % expected)
+        self.assert_(repr(t.c.t))
+        t.create(checkfirst=True)
 
     @testing.crashes('mssql', 'FIXME: unknown')
     def test_autoincrement(self):
-        meta = MetaData(testing.db)
-        try:
-            Table('ai_1', meta,
-                  Column('int_y', Integer, primary_key=True),
-                  Column('int_n', Integer, DefaultClause('0'),
-                         primary_key=True))
-            Table('ai_2', meta,
-                  Column('int_y', Integer, primary_key=True),
-                  Column('int_n', Integer, DefaultClause('0'),
-                         primary_key=True))
-            Table('ai_3', meta,
-                  Column('int_n', Integer, DefaultClause('0'),
-                         primary_key=True, autoincrement=False),
-                  Column('int_y', Integer, primary_key=True))
-            Table('ai_4', meta,
-                  Column('int_n', Integer, DefaultClause('0'),
-                         primary_key=True, autoincrement=False),
-                  Column('int_n2', Integer, DefaultClause('0'),
-                         primary_key=True, autoincrement=False))
-            Table('ai_5', meta,
-                  Column('int_y', Integer, primary_key=True),
-                  Column('int_n', Integer, DefaultClause('0'),
-                         primary_key=True, autoincrement=False))
-            Table('ai_6', meta,
-                  Column('o1', String(1), DefaultClause('x'),
-                         primary_key=True),
-                  Column('int_y', Integer, primary_key=True))
-            Table('ai_7', meta,
-                  Column('o1', String(1), DefaultClause('x'),
-                         primary_key=True),
-                  Column('o2', String(1), DefaultClause('x'),
-                         primary_key=True),
-                  Column('int_y', Integer, primary_key=True))
-            Table('ai_8', meta,
-                  Column('o1', String(1), DefaultClause('x'),
-                         primary_key=True),
-                  Column('o2', String(1), DefaultClause('x'),
-                         primary_key=True))
-            meta.create_all()
-
-            table_names = ['ai_1', 'ai_2', 'ai_3', 'ai_4',
-                           'ai_5', 'ai_6', 'ai_7', 'ai_8']
-            mr = MetaData(testing.db)
-            mr.reflect(only=table_names)
-
-            for tbl in [mr.tables[name] for name in table_names]:
-                for c in tbl.c:
-                    if c.name.startswith('int_y'):
-                        assert c.autoincrement
-                    elif c.name.startswith('int_n'):
-                        assert not c.autoincrement
-                tbl.insert().execute()
-                if 'int_y' in tbl.c:
-                    assert select([tbl.c.int_y]).scalar() == 1
-                    assert list(tbl.select().execute().fetchone()).count(1) == 1
-                else:
-                    assert 1 not in list(tbl.select().execute().fetchone())
-        finally:
-            meta.drop_all()
+        Table('ai_1', metadata,
+              Column('int_y', Integer, primary_key=True),
+              Column('int_n', Integer, DefaultClause('0'),
+                     primary_key=True))
+        Table('ai_2', metadata,
+              Column('int_y', Integer, primary_key=True),
+              Column('int_n', Integer, DefaultClause('0'),
+                     primary_key=True))
+        Table('ai_3', metadata,
+              Column('int_n', Integer, DefaultClause('0'),
+                     primary_key=True, autoincrement=False),
+              Column('int_y', Integer, primary_key=True))
+        Table('ai_4', metadata,
+              Column('int_n', Integer, DefaultClause('0'),
+                     primary_key=True, autoincrement=False),
+              Column('int_n2', Integer, DefaultClause('0'),
+                     primary_key=True, autoincrement=False))
+        Table('ai_5', metadata,
+              Column('int_y', Integer, primary_key=True),
+              Column('int_n', Integer, DefaultClause('0'),
+                     primary_key=True, autoincrement=False))
+        Table('ai_6', metadata,
+              Column('o1', String(1), DefaultClause('x'),
+                     primary_key=True),
+              Column('int_y', Integer, primary_key=True))
+        Table('ai_7', metadata,
+              Column('o1', String(1), DefaultClause('x'),
+                     primary_key=True),
+              Column('o2', String(1), DefaultClause('x'),
+                     primary_key=True),
+              Column('int_y', Integer, primary_key=True))
+        Table('ai_8', metadata,
+              Column('o1', String(1), DefaultClause('x'),
+                     primary_key=True),
+              Column('o2', String(1), DefaultClause('x'),
+                     primary_key=True))
+        metadata.create_all()
 
+        table_names = ['ai_1', 'ai_2', 'ai_3', 'ai_4',
+                       'ai_5', 'ai_6', 'ai_7', 'ai_8']
+        mr = MetaData(testing.db)
+        mr.reflect(only=table_names)
+
+        for tbl in [mr.tables[name] for name in table_names]:
+            for c in tbl.c:
+                if c.name.startswith('int_y'):
+                    assert c.autoincrement
+                elif c.name.startswith('int_n'):
+                    assert not c.autoincrement
+            tbl.insert().execute()
+            if 'int_y' in tbl.c:
+                assert select([tbl.c.int_y]).scalar() == 1
+                assert list(tbl.select().execute().fetchone()).count(1) == 1
+            else:
+                assert 1 not in list(tbl.select().execute().fetchone())
 
 class BinaryTest(TestBase, AssertsExecutionResults):
     """Test the Binary and VarBinary types"""
@@ -1025,7 +966,12 @@ class BinaryTest(TestBase, AssertsExecutionResults):
         stream2 =self.load_stream('binary_data_two.dat')
         binary_table.insert().execute(primary_id=1, misc='binary_data_one.dat', data=stream1, data_image=stream1, data_slice=stream1[0:100], pickled=testobj1, mypickle=testobj3)
         binary_table.insert().execute(primary_id=2, misc='binary_data_two.dat', data=stream2, data_image=stream2, data_slice=stream2[0:99], pickled=testobj2)
-        binary_table.insert().execute(primary_id=3, misc='binary_data_two.dat', data=None, data_image=None, data_slice=stream2[0:99], pickled=None)
+        
+        # TODO: pyodbc does not seem to accept "None" for a VARBINARY column (data=None).
+        # error:  [Microsoft][ODBC SQL Server Driver][SQL Server]Implicit conversion from 
+        # data type varchar to varbinary is not allowed. Use the CONVERT function to run this query. (257)
+        #binary_table.insert().execute(primary_id=3, misc='binary_data_two.dat', data=None, data_image=None, data_slice=stream2[0:99], pickled=None)
+        binary_table.insert().execute(primary_id=3, misc='binary_data_two.dat', data_image=None, data_slice=stream2[0:99], pickled=None)
 
         for stmt in (
             binary_table.select(order_by=binary_table.c.primary_id),