From: Mike Bayer Date: Sun, 25 Jan 2009 20:52:02 +0000 (+0000) Subject: mssql type fixes.... X-Git-Tag: rel_0_6_6~316 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=cff881cb4ce41d2199c92265dc323080a09c89ef;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git mssql type fixes.... --- diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py index 1964b6ddc5..c9c6fd729e 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -258,11 +258,7 @@ class MSNumeric(sqltypes.Numeric): def bind_processor(self, dialect): def process(value): - if value is None: - # Not sure that this exception is needed - return value - - elif isinstance(value, decimal.Decimal): + if isinstance(value, decimal.Decimal): if value.adjusted() < 0: result = "%s0.%s%s" % ( (value < 0 and '-' or ''), @@ -309,7 +305,7 @@ class MSTinyInteger(sqltypes.Integer): # filter bind parameters into datetime objects (required by pyodbc, # not sure about other dialects). 
-class MSDate(sqltypes.Date): +class MSDate(sqltypes.DATE): def bind_processor(self, dialect): def process(value): if type(value) == datetime.date: @@ -329,7 +325,7 @@ class MSDate(sqltypes.Date): return value return process -class MSTime(sqltypes.Time): +class MSTime(sqltypes.TIME): def __init__(self, precision=None, **kwargs): self.precision = precision super(MSTime, self).__init__() @@ -356,7 +352,7 @@ class MSTime(sqltypes.Time): return value return process -class MSDateTime(sqltypes.DateTime): +class _DateTimeBase(object): def bind_processor(self, dialect): def process(value): if type(value) == datetime.date: @@ -364,11 +360,14 @@ class MSDateTime(sqltypes.DateTime): else: return value return process + +class MSDateTime(_DateTimeBase, sqltypes.DATETIME): + pass -class MSSmallDateTime(MSDateTime): +class MSSmallDateTime(_DateTimeBase, sqltypes.DateTime): __visit_name__ = 'SMALLDATETIME' -class MSDateTime2(MSDateTime): +class MSDateTime2(_DateTimeBase, sqltypes.DateTime): __visit_name__ = 'DATETIME2' def __init__(self, precision=None, **kwargs): @@ -549,11 +548,12 @@ class MSNChar(_StringType, sqltypes.NCHAR): sqltypes.NCHAR.__init__(self, *args, **kw) class MSBinary(sqltypes.Binary): - pass + __visit_name__ = 'BINARY' class MSVarBinary(sqltypes.Binary): __visit_name__ = 'VARBINARY' + class MSImage(sqltypes.Binary): __visit_name__ = 'IMAGE' diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 7305f497ef..216c0a76b2 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -1011,6 +1011,9 @@ class GenericTypeCompiler(engine.TypeCompiler): def visit_date(self, type_): return self.visit_DATE(type_) + + def visit_big_integer(self, type_): + return
self.visit_BIGINT(type_) def visit_small_integer(self, type_): return self.visit_SMALLINT(type_) diff --git a/test/dialect/mssql.py b/test/dialect/mssql.py index bebda1752e..eec5518db2 100755 --- a/test/dialect/mssql.py +++ b/test/dialect/mssql.py @@ -474,11 +474,11 @@ class ParseConnectTest(TestBase, AssertsCompiledSQL): self.assertEquals([['DRIVER={SQL Server};Server=hostspec;Database=database;UID=username;PWD=password'], {}], connection) -class TypesTest(TestBase): +class TypesTest(TestBase, AssertsExecutionResults): __only_on__ = 'mssql' def setUpAll(self): - global numeric_table, metadata + global metadata metadata = MetaData(testing.db) def tearDown(self): @@ -534,11 +534,6 @@ class TypesTest(TestBase): raise e -class TypesTest2(TestBase, AssertsExecutionResults): - "Test Microsoft SQL Server column types" - - __only_on__ = 'mssql' - def test_money(self): "Exercise type specification for money types." @@ -550,7 +545,7 @@ class TypesTest2(TestBase, AssertsExecutionResults): 'SMALLMONEY'), ] - table_args = ['test_mssql_money', MetaData(testing.db)] + table_args = ['test_mssql_money', metadata] for index, spec in enumerate(columns): type_, args, kw, res = spec table_args.append(Column('c%s' % index, type_(*args, **kw), nullable=None)) @@ -611,15 +606,14 @@ class TypesTest2(TestBase, AssertsExecutionResults): ] - table_args = ['test_mssql_dates', MetaData(testing.db)] + table_args = ['test_mssql_dates', metadata] for index, spec in enumerate(columns): type_, args, kw, res, requires = spec[0:5] if (requires and testing._is_excluded('mssql', *requires)) or not requires: table_args.append(Column('c%s' % index, type_(*args, **kw), nullable=None)) dates_table = Table(*table_args) - dialect = mssql.dialect() - gen = dialect.ddl_compiler(dialect, schema.CreateTable(dates_table)) + gen = testing.db.dialect.ddl_compiler(testing.db.dialect, schema.CreateTable(dates_table)) for col in dates_table.c: index = int(col.name[1:]) @@ -627,62 +621,51 @@ class TypesTest2(TestBase, 
AssertsExecutionResults): "%s %s" % (col.name, columns[index][3])) self.assert_(repr(col)) - try: - dates_table.create(checkfirst=True) - assert True - except: - raise + dates_table.create(checkfirst=True) reflected_dates = Table('test_mssql_dates', MetaData(testing.db), autoload=True) for col in reflected_dates.c: index = int(col.name[1:]) - testing.eq_(testing.db.dialect.type_descriptor(col.type).__class__, - len(columns[index]) > 5 and columns[index][5] or columns[index][0]) - dates_table.drop() + c1 = testing.db.dialect.type_descriptor(col.type).__class__ + c2 = len(columns[index]) > 5 and columns[index][5] or columns[index][0] + assert issubclass(c1, c2), "%r is not a subclass of %r" % (c1, c2) - def test_dates2(self): - meta = MetaData(testing.db) - t = Table('test_dates', meta, + def test_date_roundtrip(self): + t = Table('test_dates', metadata, Column('id', Integer, Sequence('datetest_id_seq', optional=True), primary_key=True), Column('adate', Date), Column('atime', Time), Column('adatetime', DateTime)) - t.create(checkfirst=True) - try: - d1 = datetime.date(2007, 10, 30) - t1 = datetime.time(11, 2, 32) - d2 = datetime.datetime(2007, 10, 30, 11, 2, 32) - t.insert().execute(adate=d1, adatetime=d2, atime=t1) - t.insert().execute(adate=d2, adatetime=d2, atime=d2) - - x = t.select().execute().fetchall()[0] - self.assert_(x.adate.__class__ == datetime.date) - self.assert_(x.atime.__class__ == datetime.time) - self.assert_(x.adatetime.__class__ == datetime.datetime) + metadata.create_all() + d1 = datetime.date(2007, 10, 30) + t1 = datetime.time(11, 2, 32) + d2 = datetime.datetime(2007, 10, 30, 11, 2, 32) + t.insert().execute(adate=d1, adatetime=d2, atime=t1) + t.insert().execute(adate=d2, adatetime=d2, atime=d2) - t.delete().execute() + x = t.select().execute().fetchall()[0] + self.assert_(x.adate.__class__ == datetime.date) + self.assert_(x.atime.__class__ == datetime.time) + self.assert_(x.adatetime.__class__ == datetime.datetime) - t.insert().execute(adate=d1, 
adatetime=d2, atime=t1) + t.delete().execute() - self.assertEquals(select([t.c.adate, t.c.atime, t.c.adatetime], t.c.adate==d1).execute().fetchall(), [(d1, t1, d2)]) + t.insert().execute(adate=d1, adatetime=d2, atime=t1) - finally: - t.drop(checkfirst=True) + self.assertEquals(select([t.c.adate, t.c.atime, t.c.adatetime], t.c.adate==d1).execute().fetchall(), [(d1, t1, d2)]) def test_binary(self): "Exercise type specification for binary types." columns = [ # column type, args, kwargs, expected ddl - (types.Binary, [], {}, + (mssql.MSBinary, [], {}, 'BINARY'), (types.Binary, [10], {}, 'BINARY(10)'), - (mssql.MSBinary, [], {}, - 'IMAGE'), (mssql.MSBinary, [10], {}, 'BINARY(10)'), @@ -700,7 +683,7 @@ class TypesTest2(TestBase, AssertsExecutionResults): 'BINARY(10)') ] - table_args = ['test_mssql_binary', MetaData(testing.db)] + table_args = ['test_mssql_binary', metadata] for index, spec in enumerate(columns): type_, args, kw, res = spec table_args.append(Column('c%s' % index, type_(*args, **kw), nullable=None)) @@ -715,18 +698,15 @@ class TypesTest2(TestBase, AssertsExecutionResults): "%s %s" % (col.name, columns[index][3])) self.assert_(repr(col)) - binary_table.create(checkfirst=True) + metadata.create_all() reflected_binary = Table('test_mssql_binary', MetaData(testing.db), autoload=True) for col in reflected_binary.c: - # don't test the MSGenericBinary since it's a special case and - # reflected it will map to a MSImage or MSBinary depending - if not testing.db.dialect.type_descriptor(binary_table.c[col.name].type).__class__ == mssql.MSGenericBinary: - testing.eq_(testing.db.dialect.type_descriptor(col.type).__class__, - testing.db.dialect.type_descriptor(binary_table.c[col.name].type).__class__) + c1 =testing.db.dialect.type_descriptor(col.type).__class__ + c2 =testing.db.dialect.type_descriptor(binary_table.c[col.name].type).__class__ + assert issubclass(c1, c2), "%r is not a subclass of %r" % (c1, c2) if binary_table.c[col.name].type.length: 
testing.eq_(col.type.length, binary_table.c[col.name].type.length) - binary_table.drop() def test_boolean(self): "Exercise type specification for boolean type." @@ -737,7 +717,7 @@ class TypesTest2(TestBase, AssertsExecutionResults): 'BIT'), ] - table_args = ['test_mssql_boolean', MetaData(testing.db)] + table_args = ['test_mssql_boolean', metadata] for index, spec in enumerate(columns): type_, args, kw, res = spec table_args.append(Column('c%s' % index, type_(*args, **kw), nullable=None)) @@ -752,12 +732,7 @@ class TypesTest2(TestBase, AssertsExecutionResults): "%s %s" % (col.name, columns[index][3])) self.assert_(repr(col)) - try: - boolean_table.create(checkfirst=True) - assert True - except: - raise - boolean_table.drop() + metadata.create_all() def test_numeric(self): "Exercise type specification and options for numeric types." @@ -792,7 +767,7 @@ class TypesTest2(TestBase, AssertsExecutionResults): 'SMALLINT'), ] - table_args = ['test_mssql_numeric', MetaData(testing.db)] + table_args = ['test_mssql_numeric', metadata] for index, spec in enumerate(columns): type_, args, kw, res = spec table_args.append(Column('c%s' % index, type_(*args, **kw), nullable=None)) @@ -807,20 +782,11 @@ class TypesTest2(TestBase, AssertsExecutionResults): "%s %s" % (col.name, columns[index][3])) self.assert_(repr(col)) - try: - numeric_table.create(checkfirst=True) - assert True - except: - raise - numeric_table.drop() + metadata.create_all() def test_char(self): """Exercise COLLATE-ish options on string types.""" - # modify the text_as_varchar setting since we are not testing that behavior here - text_as_varchar = testing.db.dialect.text_as_varchar - testing.db.dialect.text_as_varchar = False - columns = [ (mssql.MSChar, [], {}, 'CHAR'), @@ -861,7 +827,7 @@ class TypesTest2(TestBase, AssertsExecutionResults): 'NTEXT COLLATE Latin1_General_CI_AS'), ] - table_args = ['test_mssql_charset', MetaData(testing.db)] + table_args = ['test_mssql_charset', metadata] for index, spec in 
enumerate(columns): type_, args, kw, res = spec table_args.append(Column('c%s' % index, type_(*args, **kw), nullable=None)) @@ -876,104 +842,79 @@ class TypesTest2(TestBase, AssertsExecutionResults): "%s %s" % (col.name, columns[index][3])) self.assert_(repr(col)) - try: - charset_table.create(checkfirst=True) - assert True - except: - raise - charset_table.drop() - - testing.db.dialect.text_as_varchar = text_as_varchar + metadata.create_all() def test_timestamp(self): """Exercise TIMESTAMP column.""" - meta = MetaData(testing.db) dialect = mssql.dialect() - try: - columns = [ - (TIMESTAMP, - 'TIMESTAMP'), - ] - for idx, (spec, expected) in enumerate(columns): - t = Table('mssql_ts%s' % idx, meta, - Column('id', Integer, primary_key=True), - Column('t', spec, nullable=None)) - gen = dialect.ddl_compiler(dialect, schema.CreateTable(t)) - testing.eq_(gen.get_column_specification(t.c.t), "t %s" % expected) - self.assert_(repr(t.c.t)) - try: - t.create(checkfirst=True) - assert True - except: - raise - t.drop() - finally: - meta.drop_all() + spec, expected = (TIMESTAMP,'TIMESTAMP') + t = Table('mssql_ts', metadata, + Column('id', Integer, primary_key=True), + Column('t', spec, nullable=None)) + gen = dialect.ddl_compiler(dialect, schema.CreateTable(t)) + testing.eq_(gen.get_column_specification(t.c.t), "t %s" % expected) + self.assert_(repr(t.c.t)) + t.create(checkfirst=True) @testing.crashes('mssql', 'FIXME: unknown') def test_autoincrement(self): - meta = MetaData(testing.db) - try: - Table('ai_1', meta, - Column('int_y', Integer, primary_key=True), - Column('int_n', Integer, DefaultClause('0'), - primary_key=True)) - Table('ai_2', meta, - Column('int_y', Integer, primary_key=True), - Column('int_n', Integer, DefaultClause('0'), - primary_key=True)) - Table('ai_3', meta, - Column('int_n', Integer, DefaultClause('0'), - primary_key=True, autoincrement=False), - Column('int_y', Integer, primary_key=True)) - Table('ai_4', meta, - Column('int_n', Integer, 
DefaultClause('0'), - primary_key=True, autoincrement=False), - Column('int_n2', Integer, DefaultClause('0'), - primary_key=True, autoincrement=False)) - Table('ai_5', meta, - Column('int_y', Integer, primary_key=True), - Column('int_n', Integer, DefaultClause('0'), - primary_key=True, autoincrement=False)) - Table('ai_6', meta, - Column('o1', String(1), DefaultClause('x'), - primary_key=True), - Column('int_y', Integer, primary_key=True)) - Table('ai_7', meta, - Column('o1', String(1), DefaultClause('x'), - primary_key=True), - Column('o2', String(1), DefaultClause('x'), - primary_key=True), - Column('int_y', Integer, primary_key=True)) - Table('ai_8', meta, - Column('o1', String(1), DefaultClause('x'), - primary_key=True), - Column('o2', String(1), DefaultClause('x'), - primary_key=True)) - meta.create_all() - - table_names = ['ai_1', 'ai_2', 'ai_3', 'ai_4', - 'ai_5', 'ai_6', 'ai_7', 'ai_8'] - mr = MetaData(testing.db) - mr.reflect(only=table_names) - - for tbl in [mr.tables[name] for name in table_names]: - for c in tbl.c: - if c.name.startswith('int_y'): - assert c.autoincrement - elif c.name.startswith('int_n'): - assert not c.autoincrement - tbl.insert().execute() - if 'int_y' in tbl.c: - assert select([tbl.c.int_y]).scalar() == 1 - assert list(tbl.select().execute().fetchone()).count(1) == 1 - else: - assert 1 not in list(tbl.select().execute().fetchone()) - finally: - meta.drop_all() + Table('ai_1', metadata, + Column('int_y', Integer, primary_key=True), + Column('int_n', Integer, DefaultClause('0'), + primary_key=True)) + Table('ai_2', metadata, + Column('int_y', Integer, primary_key=True), + Column('int_n', Integer, DefaultClause('0'), + primary_key=True)) + Table('ai_3', metadata, + Column('int_n', Integer, DefaultClause('0'), + primary_key=True, autoincrement=False), + Column('int_y', Integer, primary_key=True)) + Table('ai_4', metadata, + Column('int_n', Integer, DefaultClause('0'), + primary_key=True, autoincrement=False), + Column('int_n2', Integer, 
DefaultClause('0'), + primary_key=True, autoincrement=False)) + Table('ai_5', metadata, + Column('int_y', Integer, primary_key=True), + Column('int_n', Integer, DefaultClause('0'), + primary_key=True, autoincrement=False)) + Table('ai_6', metadata, + Column('o1', String(1), DefaultClause('x'), + primary_key=True), + Column('int_y', Integer, primary_key=True)) + Table('ai_7', metadata, + Column('o1', String(1), DefaultClause('x'), + primary_key=True), + Column('o2', String(1), DefaultClause('x'), + primary_key=True), + Column('int_y', Integer, primary_key=True)) + Table('ai_8', metadata, + Column('o1', String(1), DefaultClause('x'), + primary_key=True), + Column('o2', String(1), DefaultClause('x'), + primary_key=True)) + metadata.create_all() + table_names = ['ai_1', 'ai_2', 'ai_3', 'ai_4', + 'ai_5', 'ai_6', 'ai_7', 'ai_8'] + mr = MetaData(testing.db) + mr.reflect(only=table_names) + + for tbl in [mr.tables[name] for name in table_names]: + for c in tbl.c: + if c.name.startswith('int_y'): + assert c.autoincrement + elif c.name.startswith('int_n'): + assert not c.autoincrement + tbl.insert().execute() + if 'int_y' in tbl.c: + assert select([tbl.c.int_y]).scalar() == 1 + assert list(tbl.select().execute().fetchone()).count(1) == 1 + else: + assert 1 not in list(tbl.select().execute().fetchone()) class BinaryTest(TestBase, AssertsExecutionResults): """Test the Binary and VarBinary types""" @@ -1025,7 +966,12 @@ class BinaryTest(TestBase, AssertsExecutionResults): stream2 =self.load_stream('binary_data_two.dat') binary_table.insert().execute(primary_id=1, misc='binary_data_one.dat', data=stream1, data_image=stream1, data_slice=stream1[0:100], pickled=testobj1, mypickle=testobj3) binary_table.insert().execute(primary_id=2, misc='binary_data_two.dat', data=stream2, data_image=stream2, data_slice=stream2[0:99], pickled=testobj2) - binary_table.insert().execute(primary_id=3, misc='binary_data_two.dat', data=None, data_image=None, data_slice=stream2[0:99], pickled=None) + + 
# TODO: pyodbc does not seem to accept "None" for a VARBINARY column (data=None). + # error: [Microsoft][ODBC SQL Server Driver][SQL Server]Implicit conversion from + # data type varchar to varbinary is not allowed. Use the CONVERT function to run this query. (257) + #binary_table.insert().execute(primary_id=3, misc='binary_data_two.dat', data=None, data_image=None, data_slice=stream2[0:99], pickled=None) + binary_table.insert().execute(primary_id=3, misc='binary_data_two.dat', data_image=None, data_slice=stream2[0:99], pickled=None) for stmt in ( binary_table.select(order_by=binary_table.c.primary_id),