From af5a8222532d3799ef8f540becf46bca5da44b4a Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Sun, 2 Aug 2009 20:56:59 +0000 Subject: [PATCH] got test_mssql passing except for those tests that seem to be freetds-related --- lib/sqlalchemy/dialects/mssql/base.py | 15 +++++++-- lib/sqlalchemy/schema.py | 11 ++++-- lib/sqlalchemy/sql/compiler.py | 18 ++++++---- lib/sqlalchemy/test/testing.py | 16 +++++---- test/dialect/test_mssql.py | 48 ++++++++++++++------------- test/sql/test_query.py | 2 +- 6 files changed, 67 insertions(+), 43 deletions(-) diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py index a521932970..e3fa1d4922 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -1156,7 +1156,12 @@ class MSDialect(default.DefaultDialect): def do_release_savepoint(self, connection, name): pass - + + def initialize(self, connection): + super(MSDialect, self).initialize(connection) + if self.server_version_info >= MS_2005_VERSION and 'implicit_returning' not in self.__dict__: + self.implicit_returning = True + def get_default_schema_name(self, connection): return self.default_schema_name @@ -1317,6 +1322,7 @@ class MSDialect(default.DefaultDialect): 'type' : coltype, 'nullable' : nullable, 'default' : default, + 'autoincrement':False, } cols.append(cdict) # autoincrement and identity @@ -1338,11 +1344,14 @@ class MSDialect(default.DefaultDialect): name='%s_identity' % col_name) break cursor.close() - if not ic is None: + if ic is not None: try: # is this table_fullname reliable? table_fullname = "%s.%s" % (current_schema, tablename) - cursor = connection.execute("select ident_seed(?), ident_incr(?)", table_fullname, table_fullname) + cursor = connection.execute( + sql.text("select ident_seed(:seed), ident_incr(:incr)"), + {'seed':table_fullname, 'incr':table_fullname} + ) row = cursor.fetchone() cursor.close() if not row is None: diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py index a6961aab50..231496676c 100644 --- a/lib/sqlalchemy/schema.py +++ b/lib/sqlalchemy/schema.py @@ -123,6 +123,11 @@ class Table(SchemaItem, expression.TableClause): instance to be used for the table reflection. If ``None``, the underlying MetaData's bound connectable will be used. + :param implicit_returning: True by default - indicates that + RETURNING can be used by default to fetch newly inserted primary key + values, for backends which support this. Note that + create_engine() also provides an implicit_returning flag. + :param include_columns: A list of strings indicating a subset of columns to be loaded via the ``autoload`` operation; table columns who aren't present in this list will not be represented on the resulting ``Table`` @@ -216,6 +221,7 @@ class Table(SchemaItem, expression.TableClause): autoload_with = kwargs.pop('autoload_with', None) include_columns = kwargs.pop('include_columns', None) + self.implicit_returning = kwargs.pop('implicit_returning', True) self.quote = kwargs.pop('quote', None) self.quote_schema = kwargs.pop('quote_schema', None) if 'info' in kwargs: @@ -285,7 +291,8 @@ class Table(SchemaItem, expression.TableClause): for col in self.primary_key: if col.autoincrement and \ isinstance(col.type, types.Integer) and \ - not col.foreign_keys: + not col.foreign_keys and \ + isinstance(col.default, (type(None), Sequence)): return col @@ -482,7 +489,7 @@ class Column(SchemaItem, expression.ColumnClause): Contrast this argument to ``server_default`` which creates a default generator on the database side. - + :param key: An optional string identifier which will identify this ``Column`` object on the :class:`Table`. When a key is provided, this is the only identifier referencing the ``Column`` within the application, diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 810057946d..c981785734 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -691,9 +691,7 @@ class SQLCompiler(engine.Compiled): text += " INTO " + preparer.format_table(insert_stmt.table) - if not colparams and supports_default_values: - text += " DEFAULT VALUES" - else: + if colparams or not supports_default_values: text += " (%s)" % ', '.join([preparer.format_column(c[0]) for c in colparams]) @@ -705,8 +703,10 @@ class SQLCompiler(engine.Compiled): if returning_clause.startswith("OUTPUT"): text += " " + returning_clause returning_clause = None - - if colparams or not supports_default_values: + + if not colparams and supports_default_values: + text += " DEFAULT VALUES" + else: text += " VALUES (%s)" % \ ', '.join([c[1] for c in colparams]) @@ -780,6 +780,10 @@ class SQLCompiler(engine.Compiled): # create a list of column assignment clauses as tuples values = [] + + implicit_returning = self.dialect.implicit_returning and \ + stmt.table.implicit_returning + for c in stmt.table.columns: if c.key in parameters: value = parameters[c.key] @@ -799,12 +803,12 @@ class SQLCompiler(engine.Compiled): if c.primary_key and \ ( self.dialect.preexecute_pk_sequences or - self.dialect.implicit_returning + implicit_returning ) and \ not self.inline and \ not self.statement._returning: - if self.dialect.implicit_returning: + if implicit_returning: if isinstance(c.default, schema.Sequence): proc = self.process(c.default) if proc is not None: diff --git a/lib/sqlalchemy/test/testing.py b/lib/sqlalchemy/test/testing.py index 4a265fbec6..16a13d9d3b 100644 --- a/lib/sqlalchemy/test/testing.py +++ b/lib/sqlalchemy/test/testing.py @@ -604,18 +604,13 @@ class AssertsCompiledSQL(object): class ComparesTables(object): def assert_tables_equal(self, table, reflected_table): - base_mro = sqltypes.TypeEngine.__mro__ assert len(table.c) == len(reflected_table.c) for c, reflected_c in zip(table.c, reflected_table.c): eq_(c.name, reflected_c.name) assert reflected_c is reflected_table.c[c.name] eq_(c.primary_key, reflected_c.primary_key) eq_(c.nullable, reflected_c.nullable) - assert len( - set(type(reflected_c.type).__mro__).difference(base_mro).intersection( - set(type(c.type).__mro__).difference(base_mro) - ) - ) > 0, "On column %r, type '%s' doesn't correspond to type '%s'" % (reflected_c.name, reflected_c.type, c.type) + self.assert_types_base(reflected_c, c) if isinstance(c.type, sqltypes.String): eq_(c.type.length, reflected_c.type.length) @@ -634,7 +629,14 @@ class ComparesTables(object): assert len(table.primary_key) == len(reflected_table.primary_key) for c in table.primary_key: assert reflected_table.primary_key.columns[c.name] - + + def assert_types_base(self, c1, c2): + base_mro = sqltypes.TypeEngine.__mro__ + assert len( + set(type(c1.type).__mro__).difference(base_mro).intersection( + set(type(c2.type).__mro__).difference(base_mro) + ) + ) > 0, "On column %r, type '%s' doesn't correspond to type '%s'" % (c1.name, c1.type, c2.type) class AssertsExecutionResults(object): def assert_result(self, result, class_, *objects): diff --git a/test/dialect/test_mssql.py b/test/dialect/test_mssql.py index d8a541abf0..e2272c3ca4 100644 --- a/test/dialect/test_mssql.py +++ b/test/dialect/test_mssql.py @@ -297,7 +297,7 @@ class ReflectionTest(TestBase, ComparesTables): finally: meta.drop_all() - def testidentity(self): + def test_identity(self): meta = MetaData(testing.db) table = Table( 'identity_test', meta, @@ -343,7 +343,9 @@ class QueryTest(TestBase): meta = MetaData(testing.db) t1 = Table('t1', meta, Column('id', Integer, Sequence('fred', 100, 1), primary_key=True), - Column('descr', String(200))) + Column('descr', String(200)), + implicit_returning = False + ) t2 = Table('t2', meta, Column('id', Integer, Sequence('fred', 200, 1), primary_key=True), Column('descr', String(200))) @@ -647,7 +649,7 @@ class ParseConnectTest(TestBase, AssertsCompiledSQL): eq_([['DRIVER={SQL Server};Server=hostspec;Database=database;UID=username;PWD=password'], {}], connection) -class TypesTest(TestBase, AssertsExecutionResults): +class TypesTest(TestBase, AssertsExecutionResults, ComparesTables): __only_on__ = 'mssql' @classmethod @@ -766,7 +768,7 @@ class TypesTest(TestBase, AssertsExecutionResults): 'TIME', ['>=', (10,)]), (mssql.MSTime, [], {}, 'TIME', ['>=', (10,)]), - (types.Time, [1], {}, + (mssql.MSTime, [1], {}, 'TIME(1)', ['>=', (10,)]), (types.Time, [], {}, 'DATETIME', ['<', (10,)], mssql.MSDateTime), @@ -807,10 +809,7 @@ class TypesTest(TestBase, AssertsExecutionResults): reflected_dates = Table('test_mssql_dates', MetaData(testing.db), autoload=True) for col in reflected_dates.c: - index = int(col.name[1:]) - c1 = testing.db.dialect.type_descriptor(col.type).__class__ - c2 = len(columns[index]) > 5 and columns[index][5] or columns[index][0] - assert issubclass(c1, c2), "%r is not a subclass of %r" % (c1, c2) + self.assert_types_base(col, dates_table.c[col.key]) def test_date_roundtrip(self): t = Table('test_dates', metadata, @@ -836,7 +835,7 @@ class TypesTest(TestBase, AssertsExecutionResults): t.insert().execute(adate=d1, adatetime=d2, atime=t1) - self.assertEquals(select([t.c.adate, t.c.atime, t.c.adatetime], t.c.adate==d1).execute().fetchall(), [(d1, t1, d2)]) + eq_(select([t.c.adate, t.c.atime, t.c.adatetime], t.c.adate==d1).execute().fetchall(), [(d1, t1, d2)]) def test_binary(self): "Exercise type specification for binary types." @@ -922,16 +921,14 @@ class TypesTest(TestBase, AssertsExecutionResults): columns = [ # column type, args, kwargs, expected ddl (mssql.MSNumeric, [], {}, - 'NUMERIC(10, 2)'), + 'NUMERIC'), (mssql.MSNumeric, [None], {}, 'NUMERIC'), - (mssql.MSNumeric, [12], {}, - 'NUMERIC(12, 2)'), (mssql.MSNumeric, [12, 4], {}, 'NUMERIC(12, 4)'), (types.Float, [], {}, - 'FLOAT(10)'), + 'FLOAT'), (types.Float, [None], {}, 'FLOAT'), (types.Float, [12], {}, @@ -1040,7 +1037,6 @@ class TypesTest(TestBase, AssertsExecutionResults): self.assert_(repr(t.c.t)) t.create(checkfirst=True) - @testing.crashes('mssql', 'FIXME: unknown') def test_autoincrement(self): Table('ai_1', metadata, Column('int_y', Integer, primary_key=True), @@ -1083,21 +1079,27 @@ class TypesTest(TestBase, AssertsExecutionResults): table_names = ['ai_1', 'ai_2', 'ai_3', 'ai_4', 'ai_5', 'ai_6', 'ai_7', 'ai_8'] mr = MetaData(testing.db) - mr.reflect(only=table_names) - for tbl in [mr.tables[name] for name in table_names]: + for name in table_names: + tbl = Table(name, mr, autoload=True) for c in tbl.c: if c.name.startswith('int_y'): assert c.autoincrement elif c.name.startswith('int_n'): assert not c.autoincrement - tbl.insert().execute() - if 'int_y' in tbl.c: - assert select([tbl.c.int_y]).scalar() == 1 - assert list(tbl.select().execute().first()).count(1) == 1 - else: - assert 1 not in list(tbl.select().execute().first()) - + + for counter, engine in enumerate([ + engines.testing_engine(options={'implicit_returning':False}), + engines.testing_engine(options={'implicit_returning':True}), + ] + ): + engine.execute(tbl.insert()) + if 'int_y' in tbl.c: + assert engine.scalar(select([tbl.c.int_y])) == counter + 1 + assert list(engine.execute(tbl.select()).first()).count(counter + 1) == 1 + else: + assert 1 not in list(engine.execute(tbl.select()).first()) + engine.execute(tbl.delete()) class BinaryTest(TestBase, AssertsExecutionResults): """Test the Binary and VarBinary types""" diff --git a/test/sql/test_query.py b/test/sql/test_query.py index 979c148e4a..b3a9eb0ccb 100644 --- a/test/sql/test_query.py +++ b/test/sql/test_query.py @@ -80,7 +80,7 @@ class QueryTest(TestBase): ret[c.key] = row[c] return ret - if testing.against('firebird', 'postgres', 'oracle'): #, 'mssql'): + if testing.against('firebird', 'postgres', 'oracle', 'mssql'): test_engines = [ engines.testing_engine(options={'implicit_returning':False}), engines.testing_engine(options={'implicit_returning':True}), -- 2.47.3