]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
MSSQL now passes still more unit tests [ticket:481]
authorRick Morrison <rickmorrison@gmail.com>
Thu, 15 Mar 2007 02:31:15 +0000 (02:31 +0000)
committerRick Morrison <rickmorrison@gmail.com>
Thu, 15 Mar 2007 02:31:15 +0000 (02:31 +0000)
Fix to null FLOAT fields in mssql-trusted.patch
MSSQL: LIMIT with OFFSET now raises an error
MSSQL: can now specify Windows authorization
MSSQL: ignores seconds on DATE columns (DATE fix, part 1)

CHANGES
lib/sqlalchemy/databases/mssql.py
test/engine/pool.py
test/engine/reflection.py
test/ext/selectresults.py
test/orm/generative.py
test/orm/mapper.py
test/sql/query.py

diff --git a/CHANGES b/CHANGES
index 745922297b3e5001b5d3ca74d1e4ad1844258f53..9f72facd3a09813b0d7d8b29ccb1291397c65295 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -48,6 +48,7 @@
     - the "else_" parameter to the case statement now properly works when
     set to zero.
 
+
 - oracle:
     - got binary working for any size input !  cx_oracle works fine,
       it was my fault as BINARY was being passed and not BLOB for
 
     - query() method is added by assignmapper.  this helps with 
       navigating to all the new generative methods on Query.
+
+- ms-sql:
+    - removed seconds input on DATE column types (probably 
+        should remove the time altogether)
+
+    - null values in float fields no longer raise errors
+
+    - LIMIT with OFFSET now raises an error (MS-SQL has no OFFSET support)
+
+
     
 0.3.5
 - sql:
index 8c3c71f6edd39d991da54d389d46e7ce1ecd4f01..60d52d1819b79be0abd5515e74a20abd5669ce77 100644 (file)
@@ -61,8 +61,16 @@ def use_adodbapi():
     # ADODBAPI has a non-standard Connection method
     connect = dbmodule.Connection
     def make_connect_string(keys):
-        return  [["Provider=SQLOLEDB;Data Source=%s;User Id=%s;Password=%s;Initial Catalog=%s" % (
-            keys.get("host"), keys.get("user"), keys.get("password", ""), keys.get("database"))], {}]        
+        connectors = ["Provider=SQLOLEDB"]
+        connectors.append ("Data Source=%s" % keys.get("host"))
+        connectors.append ("Initial Catalog=%s" % keys.get("database"))
+        user = keys.get("user")
+        if user:
+            connectors.append("User Id=%s" % user)
+            connectors.append("Password=%s" % keys.get("password", ""))
+        else:
+            connectors.append("Integrated Security=SSPI")
+        return [[";".join (connectors)], {}]
     sane_rowcount = True
     dialect = MSSQLDialect
     colspecs[sqltypes.Unicode] = AdoMSUnicode
@@ -91,8 +99,16 @@ def use_pyodbc():
     import pyodbc as dbmodule
     connect = dbmodule.connect
     def make_connect_string(keys):
-        return [["Driver={SQL Server};Server=%s;UID=%s;PWD=%s;Database=%s" % (
-            keys.get("host"), keys.get("user"), keys.get("password", ""), keys.get("database"))], {}]        
+        connectors = ["Driver={SQL Server}"]
+        connectors.append("Server=%s" % keys.get("host"))
+        connectors.append("Database=%s" % keys.get("database"))
+        user = keys.get("user")
+        if user:
+            connectors.append("UID=%s" % user)
+            connectors.append("PWD=%s" % keys.get("password", ""))
+        else:
+            connectors.append ("TrustedConnection=Yes")
+        return [[";".join (connectors)], {}]
     do_commit = True
     sane_rowcount = False
     dialect = MSSQLDialect
@@ -150,7 +166,7 @@ class MSFloat(sqltypes.Float):
 
     def convert_bind_param(self, value, dialect):
         """By converting to string, we can use Decimal types round-trip."""
-        return str(value) 
+        return value and str(value) or None
 
 class MSInteger(sqltypes.Integer):
     def get_col_spec(self):
@@ -195,7 +211,7 @@ class MSDate(sqltypes.Date):
     
     def convert_bind_param(self, value, dialect):
         if value and hasattr(value, "isoformat"):
-            return value.strftime('%Y-%m-%d %H:%M:%S')
+            return value.strftime('%Y-%m-%d %H:%M')
         return value
 
     def convert_result_value(self, value, dialect):
@@ -327,26 +343,29 @@ class MSSQLExecutionContext(default.DefaultExecutionContext):
     def __init__(self, dialect):
         self.IINSERT = self.HASIDENT = False
         super(MSSQLExecutionContext, self).__init__(dialect)
-    
+
+    def _has_implicit_sequence(self, column):
+        if column.primary_key and column.autoincrement:
+            if isinstance(column.type, sqltypes.Integer) and not column.foreign_key:
+                if column.default is None or (isinstance(column.default, schema.Sequence) and \
+                                              column.default.optional):
+                    return True
+        return False
+
     def pre_exec(self, engine, proxy, compiled, parameters, **kwargs):
         """MS-SQL has a special mode for inserting non-NULL values
         into IDENTITY columns.
         
         Activate it if the feature is turned on and needed.
         """
-        
         if getattr(compiled, "isinsert", False):
             tbl = compiled.statement.table
-            if not hasattr(tbl, 'has_sequence'):                
+            if not hasattr(tbl, 'has_sequence'):
+                tbl.has_sequence = False
                 for column in tbl.c:
-                    if column.primary_key and column.autoincrement and \
-                           isinstance(column.type, sqltypes.Integer) and not column.foreign_key:
-                        if column.default is None or (isinstance(column.default, schema.Sequence) and \
-                                                      column.default.optional):
-                            tbl.has_sequence = column
-                            break
-                else:
-                    tbl.has_sequence = False
+                    if getattr(column, 'sequence', False) or self._has_implicit_sequence(column):
+                        tbl.has_sequence = column
+                        break
 
             self.HASIDENT = bool(tbl.has_sequence)
             if engine.dialect.auto_identity_insert and self.HASIDENT:
@@ -520,6 +539,10 @@ class MSSQLDialect(ansisql.ANSIDialect):
                 row[columns.c.column_default]
             )
 
+            # cope with varchar(max)
+            if charlen == -1:
+                charlen = None
+                
             args = []
             for a in (charlen, numericprec, numericscale):
                 if a is not None:
@@ -644,12 +667,14 @@ class MSSQLCompiler(ansisql.ANSICompiler):
     def visit_select_precolumns(self, select):
         """ MS-SQL puts TOP, it's version of LIMIT here """
         s = select.distinct and "DISTINCT " or ""
-        if (select.limit):
+        if select.limit:
             s += "TOP %s " % (select.limit,)
+        if select.offset:
+            raise exceptions.InvalidRequestError('MSSQL does not support LIMIT with an offset')
         return s
 
-    def limit_clause(self, select):
-        # Limit in mssql is after the select keyword; MSsql has no support for offset
+    def limit_clause(self, select):    
+        # Limit in mssql is after the select keyword
         return ""
             
     def visit_table(self, table):
@@ -744,3 +769,4 @@ class MSSQLIdentifierPreparer(ansisql.ANSIIdentifierPreparer):
 
 use_default()
 
+
index 08df106ce267b31e9d2186a974ef5d1062642aea..db97ea6f8db8505b460d692b3ec64aa1597a2285 100644 (file)
@@ -162,7 +162,7 @@ class PoolTest(PersistTest):
         c2 = p.connect()
         assert id(c2.connection) == c_id
         c2.close()
-        time.sleep(3)
+        time.sleep(4)
         c3= p.connect()
         assert id(c3.connection) != c_id
     
index 51a3d35c675768303a564b3bcc9866c53b1e0e1a..62cd92b6e6e358ef2a50b25a46ac6ef8965b500a 100644 (file)
@@ -267,6 +267,7 @@ class ReflectionTest(PersistTest):
             testbase.db.execute("drop table django_admin_log")
             testbase.db.execute("drop table django_content_type")
 
+    @testbase.unsupported('mssql')
     def testmultipk(self):
         """test that creating a table checks for a sequence before creating it"""
         meta = BoundMetaData(testbase.db)
index 88476c9cc0b96a59428070e02b6458c0192a9543..8df416be94c9ce74475ab64a2065bd244665bc9e 100644 (file)
@@ -39,7 +39,8 @@ class SelectResultsTest(PersistTest):
         res = self.query.select_by(range=5)
         assert res.order_by([Foo.c.bar])[0].bar == 5
         assert res.order_by([desc(Foo.c.bar)])[0].bar == 95
-        
+
+    @testbase.unsupported('mssql')
     def test_slice(self):
         assert self.res[1] == self.orig[1]
         assert list(self.res[10:20]) == self.orig[10:20]
@@ -50,6 +51,11 @@ class SelectResultsTest(PersistTest):
         assert list(self.res[-5:]) == self.orig[-5:]
         assert self.res[10:20][5] == self.orig[10:20][5]
 
+    @testbase.supported('mssql')
+    def test_slice_mssql(self):
+        assert list(self.res[:10]) == self.orig[:10]
+        assert list(self.res[:10]) == self.orig[:10]
+
     def test_aggregate(self):
         assert self.res.count() == 100
         assert self.res.filter(foo.c.bar<30).min(foo.c.bar) == 0
@@ -60,11 +66,14 @@ class SelectResultsTest(PersistTest):
         # this one fails in mysql as the result comes back as a string
         assert self.res.filter(foo.c.bar<30).sum(foo.c.bar) == 435
 
-    @testbase.unsupported('postgres', 'mysql', 'firebird')
+    @testbase.unsupported('postgres', 'mysql', 'firebird', 'mssql')
     def test_aggregate_2(self):
-        # this one fails with postgres, the floating point comparison fails
         assert self.res.filter(foo.c.bar<30).avg(foo.c.bar) == 14.5
 
+    @testbase.supported('postgres', 'mysql', 'firebird', 'mssql')
+    def test_aggregate_2_int(self):
+        assert int(self.res.filter(foo.c.bar<30).avg(foo.c.bar)) == 14
+
     def test_filter(self):
         assert self.res.count() == 100
         assert self.res.filter(Foo.c.bar < 30).count() == 30
index 37ce1dcc9bc53d772799b2601814b4d6142a9c20..b8c2a85e1b5b7fd7b7feb348537d8a06b5155671 100644 (file)
@@ -39,6 +39,7 @@ class GenerativeQueryTest(PersistTest):
         assert res.order_by([Foo.c.bar])[0].bar == 5
         assert res.order_by([desc(Foo.c.bar)])[0].bar == 95
         
+    @testbase.unsupported('mssql')
     def test_slice(self):
         assert self.query[1] == self.orig[1]
         assert list(self.query[10:20]) == self.orig[10:20]
@@ -49,6 +50,11 @@ class GenerativeQueryTest(PersistTest):
         assert list(self.query[-5:]) == self.orig[-5:]
         assert self.query[10:20][5] == self.orig[10:20][5]
 
+    @testbase.supported('mssql')
+    def test_slice_mssql(self):
+        assert list(self.query[:10]) == self.orig[:10]
+        assert list(self.query[:10]) == self.orig[:10]
+
     def test_aggregate(self):
         assert self.query.count() == 100
         assert self.query.filter(foo.c.bar<30).min(foo.c.bar) == 0
@@ -59,10 +65,13 @@ class GenerativeQueryTest(PersistTest):
         # this one fails in mysql as the result comes back as a string
         assert self.query.filter(foo.c.bar<30).sum(foo.c.bar) == 435
 
-    @testbase.unsupported('postgres', 'mysql', 'firebird')
+    @testbase.unsupported('postgres', 'mysql', 'firebird', 'mssql')
     def test_aggregate_2(self):
-        # this one fails with postgres, the floating point comparison fails
-        assert self.query.filter(foo.c.bar<30).avg(foo.c.bar) == 14.5
+        assert self.res.filter(foo.c.bar<30).avg(foo.c.bar) == 14.5
+
+    @testbase.supported('postgres', 'mysql', 'firebird', 'mssql')
+    def test_aggregate_2_int(self):
+        assert int(self.res.filter(foo.c.bar<30).avg(foo.c.bar)) == 14
 
     def test_filter(self):
         assert self.query.count() == 100
index 52c0e37e6cd825c3a509892d59b3f7c7ff83e488..f5a4613c95c312324886d6de4575bdab8c9b9b54 100644 (file)
@@ -983,8 +983,14 @@ class LazyTest(MapperSuperTest):
         ))
         sess= create_session()
         q = sess.query(m)
-        l = q.select(limit=2, offset=1)
-        self.assert_result(l, User, *user_all_result[1:3])
+        
+        if db.engine.name == 'mssql':
+            l = q.select(limit=2)
+            self.assert_result(l, User, *user_all_result[:2])
+        else:        
+            l = q.select(limit=2, offset=1)
+            self.assert_result(l, User, *user_all_result[1:3])
+
         # use a union all to get a lot of rows to join against
         u2 = users.alias('u2')
         s = union_all(u2.select(use_labels=True), u2.select(use_labels=True), u2.select(use_labels=True)).alias('u')
@@ -1124,8 +1130,13 @@ class EagerTest(MapperSuperTest):
         sess = create_session()
         q = sess.query(m)
         
-        l = q.select(limit=2, offset=1)
-        self.assert_result(l, User, *user_all_result[1:3])
+        if db.engine.name == 'mssql':
+            l = q.select(limit=2)
+            self.assert_result(l, User, *user_all_result[:2])
+        else:        
+            l = q.select(limit=2, offset=1)
+            self.assert_result(l, User, *user_all_result[1:3])
+
         # this is an involved 3x union of the users table to get a lot of rows.
         # then see if the "distinct" works its way out.  you actually get the same
         # result with or without the distinct, just via less or more rows.
@@ -1156,8 +1167,9 @@ class EagerTest(MapperSuperTest):
         sess = create_session()
         q = sess.query(m)
         
-        l = q.select(q.join_to('orders'), order_by=desc(orders.c.user_id), limit=2, offset=1)
-        self.assert_result(l, User, *(user_all_result[2], user_all_result[0]))
+        if db.engine.name != 'mssql':
+            l = q.select(q.join_to('orders'), order_by=desc(orders.c.user_id), limit=2, offset=1)
+            self.assert_result(l, User, *(user_all_result[2], user_all_result[0]))
         
         l = q.select(q.join_to('addresses'), order_by=desc(addresses.c.email_address), limit=1, offset=0)
         self.assert_result(l, User, *(user_all_result[0],))
index f9ba9409d689e5624be6cf88fab7267610623a53..6683fa0d0467ffb1cc2316cb5fd5b3859691e109 100644 (file)
@@ -174,6 +174,14 @@ class QueryTest(PersistTest):
         r = self.users.select(offset=5, order_by=[self.users.c.user_id]).execute().fetchall()
         self.assert_(r==[(6, 'ralph'), (7, 'fido')])
         
+    @testbase.supported('mssql')
+    def testselectlimitoffset_mssql(self):
+        try:
+            r = self.users.select(limit=3, offset=2, order_by=[self.users.c.user_id]).execute().fetchall()
+            assert False # InvalidRequestError should have been raised
+        except exceptions.InvalidRequestError:
+            pass
+
     @testbase.unsupported('mysql')  
     def test_scalar_select(self):
         """test that scalar subqueries with labels get their type propigated to the result set."""