]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
got test_mssql passing except for those tests that seem to be freetds-related
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 2 Aug 2009 20:56:59 +0000 (20:56 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 2 Aug 2009 20:56:59 +0000 (20:56 +0000)
lib/sqlalchemy/dialects/mssql/base.py
lib/sqlalchemy/schema.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/test/testing.py
test/dialect/test_mssql.py
test/sql/test_query.py

index a521932970c954bf06268bd422fae9fa97c54142..e3fa1d4922da70d4f09476455c3ba853e81a8cdd 100644 (file)
@@ -1156,7 +1156,12 @@ class MSDialect(default.DefaultDialect):
 
     def do_release_savepoint(self, connection, name):
         pass
-
+    
+    def initialize(self, connection):
+        super(MSDialect, self).initialize(connection)
+        if self.server_version_info >= MS_2005_VERSION and 'implicit_returning' not in self.__dict__:
+            self.implicit_returning = True
+        
     def get_default_schema_name(self, connection):
         return self.default_schema_name
         
@@ -1317,6 +1322,7 @@ class MSDialect(default.DefaultDialect):
                 'type' : coltype,
                 'nullable' : nullable,
                 'default' : default,
+                'autoincrement':False,
             }
             cols.append(cdict)
         # autoincrement and identity
@@ -1338,11 +1344,14 @@ class MSDialect(default.DefaultDialect):
                                     name='%s_identity' % col_name)
                 break
         cursor.close()
-        if not ic is None:
+        if ic is not None:
             try:
                 # is this table_fullname reliable?
                 table_fullname = "%s.%s" % (current_schema, tablename)
-                cursor = connection.execute("select ident_seed(?), ident_incr(?)", table_fullname, table_fullname)
+                cursor = connection.execute(
+                    sql.text("select ident_seed(:seed), ident_incr(:incr)"), 
+                    {'seed':table_fullname, 'incr':table_fullname}
+                )
                 row = cursor.fetchone()
                 cursor.close()
                 if not row is None:
index a6961aab50a248d6b680f4ff150a6c12e3a51330..231496676c11c2f82fbebacd152edb0cedcf8171 100644 (file)
@@ -123,6 +123,11 @@ class Table(SchemaItem, expression.TableClause):
         instance to be used for the table reflection.  If ``None``, the
         underlying MetaData's bound connectable will be used.
 
+    :param implicit_returning: True by default - indicates that 
+        RETURNING can be used by default to fetch newly inserted primary key 
+        values, for backends which support this.  Note that 
+        create_engine() also provides an implicit_returning flag.
+
     :param include_columns: A list of strings indicating a subset of columns to be loaded via
         the ``autoload`` operation; table columns who aren't present in
         this list will not be represented on the resulting ``Table``
@@ -216,6 +221,7 @@ class Table(SchemaItem, expression.TableClause):
         autoload_with = kwargs.pop('autoload_with', None)
         include_columns = kwargs.pop('include_columns', None)
 
+        self.implicit_returning = kwargs.pop('implicit_returning', True)
         self.quote = kwargs.pop('quote', None)
         self.quote_schema = kwargs.pop('quote_schema', None)
         if 'info' in kwargs:
@@ -285,7 +291,8 @@ class Table(SchemaItem, expression.TableClause):
         for col in self.primary_key:
             if col.autoincrement and \
                 isinstance(col.type, types.Integer) and \
-                not col.foreign_keys:
+                not col.foreign_keys and \
+                isinstance(col.default, (type(None), Sequence)):
 
                 return col
 
@@ -482,7 +489,7 @@ class Column(SchemaItem, expression.ColumnClause):
           
           Contrast this argument to ``server_default`` which creates a 
           default generator on the database side.
-
+        
         :param key: An optional string identifier which will identify this ``Column`` 
             object on the :class:`Table`.  When a key is provided, this is the
             only identifier referencing the ``Column`` within the application,
index 810057946dd67fb62ef5f07c6e7b852ff25247fe..c981785734584059e09e88db35b8b7c22ce01bf0 100644 (file)
@@ -691,9 +691,7 @@ class SQLCompiler(engine.Compiled):
         
         text += " INTO " + preparer.format_table(insert_stmt.table)
          
-        if not colparams and supports_default_values:
-            text += " DEFAULT VALUES"
-        else: 
+        if colparams or not supports_default_values:
             text += " (%s)" % ', '.join([preparer.format_column(c[0])
                        for c in colparams])
 
@@ -705,8 +703,10 @@ class SQLCompiler(engine.Compiled):
             if returning_clause.startswith("OUTPUT"):
                 text += " " + returning_clause
                 returning_clause = None
-                
-        if colparams or not supports_default_values:
+
+        if not colparams and supports_default_values:
+            text += " DEFAULT VALUES"
+        else:
             text += " VALUES (%s)" % \
                      ', '.join([c[1] for c in colparams])
         
@@ -780,6 +780,10 @@ class SQLCompiler(engine.Compiled):
 
         # create a list of column assignment clauses as tuples
         values = []
+        
+        implicit_returning = self.dialect.implicit_returning and \
+                                stmt.table.implicit_returning
+        
         for c in stmt.table.columns:
             if c.key in parameters:
                 value = parameters[c.key]
@@ -799,12 +803,12 @@ class SQLCompiler(engine.Compiled):
                     if c.primary_key and \
                         (
                             self.dialect.preexecute_pk_sequences or 
-                            self.dialect.implicit_returning
+                            implicit_returning
                         ) and \
                         not self.inline and \
                         not self.statement._returning:
 
-                        if self.dialect.implicit_returning:
+                        if implicit_returning:
                             if isinstance(c.default, schema.Sequence):
                                 proc = self.process(c.default)
                                 if proc is not None:
index 4a265fbec6010bd1479cef26ba46e8a5d00fb609..16a13d9d3b8f6bbd1b52a255593490b30227db97 100644 (file)
@@ -604,18 +604,13 @@ class AssertsCompiledSQL(object):
 
 class ComparesTables(object):
     def assert_tables_equal(self, table, reflected_table):
-        base_mro = sqltypes.TypeEngine.__mro__
         assert len(table.c) == len(reflected_table.c)
         for c, reflected_c in zip(table.c, reflected_table.c):
             eq_(c.name, reflected_c.name)
             assert reflected_c is reflected_table.c[c.name]
             eq_(c.primary_key, reflected_c.primary_key)
             eq_(c.nullable, reflected_c.nullable)
-            assert len(
-                set(type(reflected_c.type).__mro__).difference(base_mro).intersection(
-                set(type(c.type).__mro__).difference(base_mro)
-                )
-            ) > 0, "On column %r, type '%s' doesn't correspond to type '%s'" % (reflected_c.name, reflected_c.type, c.type)
+            self.assert_types_base(reflected_c, c)
 
             if isinstance(c.type, sqltypes.String):
                 eq_(c.type.length, reflected_c.type.length)
@@ -634,7 +629,14 @@ class ComparesTables(object):
         assert len(table.primary_key) == len(reflected_table.primary_key)
         for c in table.primary_key:
             assert reflected_table.primary_key.columns[c.name]
-
+    
+    def assert_types_base(self, c1, c2):
+        base_mro = sqltypes.TypeEngine.__mro__
+        assert len(
+            set(type(c1.type).__mro__).difference(base_mro).intersection(
+            set(type(c2.type).__mro__).difference(base_mro)
+            )
+        ) > 0, "On column %r, type '%s' doesn't correspond to type '%s'" % (c1.name, c1.type, c2.type)
 
 class AssertsExecutionResults(object):
     def assert_result(self, result, class_, *objects):
index d8a541abf0529b539e58aaf7fb4ca410cfc4609e..e2272c3ca4ee86d74338bd59ac9251983509e18c 100644 (file)
@@ -297,7 +297,7 @@ class ReflectionTest(TestBase, ComparesTables):
         finally:
             meta.drop_all()
 
-    def testidentity(self):
+    def test_identity(self):
         meta = MetaData(testing.db)
         table = Table(
             'identity_test', meta,
@@ -343,7 +343,9 @@ class QueryTest(TestBase):
         meta = MetaData(testing.db)
         t1 = Table('t1', meta,
                 Column('id', Integer, Sequence('fred', 100, 1), primary_key=True),
-                Column('descr', String(200)))
+                Column('descr', String(200)),
+                implicit_returning = False
+                )
         t2 = Table('t2', meta,
                 Column('id', Integer, Sequence('fred', 200, 1), primary_key=True),
                 Column('descr', String(200)))
@@ -647,7 +649,7 @@ class ParseConnectTest(TestBase, AssertsCompiledSQL):
         eq_([['DRIVER={SQL Server};Server=hostspec;Database=database;UID=username;PWD=password'], {}], connection)
 
 
-class TypesTest(TestBase, AssertsExecutionResults):
+class TypesTest(TestBase, AssertsExecutionResults, ComparesTables):
     __only_on__ = 'mssql'
 
     @classmethod
@@ -766,7 +768,7 @@ class TypesTest(TestBase, AssertsExecutionResults):
              'TIME', ['>=', (10,)]),
             (mssql.MSTime, [], {},
              'TIME', ['>=', (10,)]),
-            (types.Time, [1], {},
+            (mssql.MSTime, [1], {},
              'TIME(1)', ['>=', (10,)]),
             (types.Time, [], {},
              'DATETIME', ['<', (10,)], mssql.MSDateTime),
@@ -807,10 +809,7 @@ class TypesTest(TestBase, AssertsExecutionResults):
 
         reflected_dates = Table('test_mssql_dates', MetaData(testing.db), autoload=True)
         for col in reflected_dates.c:
-            index = int(col.name[1:])
-            c1 = testing.db.dialect.type_descriptor(col.type).__class__
-            c2 = len(columns[index]) > 5 and columns[index][5] or columns[index][0]
-            assert issubclass(c1, c2), "%r is not a subclass of %r" % (c1, c2)
+            self.assert_types_base(col, dates_table.c[col.key])
 
     def test_date_roundtrip(self):
         t = Table('test_dates', metadata,
@@ -836,7 +835,7 @@ class TypesTest(TestBase, AssertsExecutionResults):
 
         t.insert().execute(adate=d1, adatetime=d2, atime=t1)
 
-        self.assertEquals(select([t.c.adate, t.c.atime, t.c.adatetime], t.c.adate==d1).execute().fetchall(), [(d1, t1, d2)])
+        eq_(select([t.c.adate, t.c.atime, t.c.adatetime], t.c.adate==d1).execute().fetchall(), [(d1, t1, d2)])
 
     def test_binary(self):
         "Exercise type specification for binary types."
@@ -922,16 +921,14 @@ class TypesTest(TestBase, AssertsExecutionResults):
         columns = [
             # column type, args, kwargs, expected ddl
             (mssql.MSNumeric, [], {},
-             'NUMERIC(10, 2)'),
+             'NUMERIC'),
             (mssql.MSNumeric, [None], {},
              'NUMERIC'),
-            (mssql.MSNumeric, [12], {},
-             'NUMERIC(12, 2)'),
             (mssql.MSNumeric, [12, 4], {},
              'NUMERIC(12, 4)'),
 
             (types.Float, [], {},
-             'FLOAT(10)'),
+             'FLOAT'),
             (types.Float, [None], {},
              'FLOAT'),
             (types.Float, [12], {},
@@ -1040,7 +1037,6 @@ class TypesTest(TestBase, AssertsExecutionResults):
         self.assert_(repr(t.c.t))
         t.create(checkfirst=True)
         
-    @testing.crashes('mssql', 'FIXME: unknown')
     def test_autoincrement(self):
         Table('ai_1', metadata,
                Column('int_y', Integer, primary_key=True),
@@ -1083,21 +1079,27 @@ class TypesTest(TestBase, AssertsExecutionResults):
         table_names = ['ai_1', 'ai_2', 'ai_3', 'ai_4',
                         'ai_5', 'ai_6', 'ai_7', 'ai_8']
         mr = MetaData(testing.db)
-        mr.reflect(only=table_names)
 
-        for tbl in [mr.tables[name] for name in table_names]:
+        for name in table_names:
+            tbl = Table(name, mr, autoload=True)
             for c in tbl.c:
                 if c.name.startswith('int_y'):
                     assert c.autoincrement
                 elif c.name.startswith('int_n'):
                     assert not c.autoincrement
-            tbl.insert().execute()
-            if 'int_y' in tbl.c:
-                assert select([tbl.c.int_y]).scalar() == 1
-                assert list(tbl.select().execute().first()).count(1) == 1
-            else:
-                assert 1 not in list(tbl.select().execute().first())
-
+            
+            for counter, engine in enumerate([
+                engines.testing_engine(options={'implicit_returning':False}),
+                engines.testing_engine(options={'implicit_returning':True}),
+                ]
+            ):
+                engine.execute(tbl.insert())
+                if 'int_y' in tbl.c:
+                    assert engine.scalar(select([tbl.c.int_y])) == counter + 1
+                    assert list(engine.execute(tbl.select()).first()).count(counter + 1) == 1
+                else:
+                    assert 1 not in list(engine.execute(tbl.select()).first())
+                engine.execute(tbl.delete())
 
 class BinaryTest(TestBase, AssertsExecutionResults):
     """Test the Binary and VarBinary types"""
index 979c148e4ae8d0c129f16e9887912d7452fca4a0..b3a9eb0ccbe2adbb48b17b9af567055753487cb0 100644 (file)
@@ -80,7 +80,7 @@ class QueryTest(TestBase):
                     ret[c.key] = row[c]
             return ret
 
-        if testing.against('firebird', 'postgres', 'oracle'): #, 'mssql'):
+        if testing.against('firebird', 'postgres', 'oracle', 'mssql'):
             test_engines = [
                 engines.testing_engine(options={'implicit_returning':False}),
                 engines.testing_engine(options={'implicit_returning':True}),