]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Multiple MSSQL fixes; see ticket #581
authorPaul Johnston <paj@pajhome.org.uk>
Wed, 13 Jun 2007 18:53:16 +0000 (18:53 +0000)
committerPaul Johnston <paj@pajhome.org.uk>
Wed, 13 Jun 2007 18:53:16 +0000 (18:53 +0000)
lib/sqlalchemy/databases/mssql.py
lib/sqlalchemy/sql.py
test/sql/query.py
test/sql/rowcount.py
test/sql/testtypes.py

index 1cadbd14d01847fd634bab1c6b3c9e09d3f9071c..4336296dd90fa520d9615c8b287cbcfdbc7c5415 100644 (file)
@@ -99,14 +99,14 @@ class MSDateTime(sqltypes.DateTime):
     def get_col_spec(self):
         return "DATETIME"
 
-    def convert_bind_param(self, value, dialect):
-        if hasattr(value, "isoformat"):
-            #return value.isoformat(' ')
-            # isoformat() bings on apodbapi -- reported/suggested by Peter Buschman
-            return value.strftime('%Y-%m-%d %H:%M:%S')
-        else:
-            return value
+class MSDate(sqltypes.Date):
+    def __init__(self, *a, **kw):
+        super(MSDate, self).__init__(False)
 
+    def get_col_spec(self):
+        return "SMALLDATETIME"
+
+class MSDateTime_adodbapi(MSDateTime):
     def convert_result_value(self, value, dialect):
         # adodbapi will return datetimes with empty time values as datetime.date() objects.
         # Promote them back to full datetime.datetime()
@@ -114,23 +114,34 @@ class MSDateTime(sqltypes.DateTime):
             return datetime.datetime(value.year, value.month, value.day)
         return value
 
-class MSDate(sqltypes.Date):
-    def __init__(self, *a, **kw):
-        super(MSDate, self).__init__(False)
+class MSDateTime_pyodbc(MSDateTime):
+    def convert_bind_param(self, value, dialect):
+        if value and not hasattr(value, 'second'):
+            return datetime.datetime(value.year, value.month, value.day)
+        else:
+            return value
 
-    def get_col_spec(self):
-        return "SMALLDATETIME"
-    
+class MSDate_pyodbc(MSDate):
     def convert_bind_param(self, value, dialect):
-        if value and hasattr(value, "isoformat"):
-            return value.strftime('%Y-%m-%d %H:%M')
-        return value
+        if value and not hasattr(value, 'second'):
+            return datetime.datetime(value.year, value.month, value.day)
+        else:
+            return value
 
+    def convert_result_value(self, value, dialect):
+        # pyodbc returns SMALLDATETIME values as datetime.datetime(). truncate it back to datetime.date()
+        if value and hasattr(value, 'second'):
+            return value.date()
+        else:
+            return value
+
+class MSDate_pymssql(MSDate):
     def convert_result_value(self, value, dialect):
         # pymssql will return SMALLDATETIME values as datetime.datetime(), truncate it back to datetime.date()
         if value and hasattr(value, 'second'):
             return value.date()
-        return value
+        else:
+            return value
 
 class MSText(sqltypes.TEXT):
     def get_col_spec(self):
@@ -143,7 +154,7 @@ class MSString(sqltypes.String):
     def get_col_spec(self):
         return "VARCHAR(%(length)s)" % {'length' : self.length}
 
-class MSNVarchar(MSString):
+class MSNVarchar(sqltypes.Unicode):
     def get_col_spec(self):
         if self.length:
             return "NVARCHAR(%(length)s)" % {'length' : self.length}
@@ -191,6 +202,10 @@ class MSBoolean(sqltypes.Boolean):
         else:
             return value and True or False
         
+class MSTimeStamp(sqltypes.TIMESTAMP):
+    def get_col_spec(self):
+        return "TIMESTAMP"
+        
 def descriptor():
     return {'name':'mssql',
     'description':'MSSQL',
@@ -240,7 +255,7 @@ class MSSQLExecutionContext(default.DefaultExecutionContext):
 
             if self.IINSERT:
                 # TODO: quoting rules for table name here ?
-                self.cursor.execute("SET IDENTITY_INSERT %s ON" % self.compiled.statement.table.name)
+                self.cursor.execute("SET IDENTITY_INSERT %s ON" % self.compiled.statement.table.fullname)
 
         super(MSSQLExecutionContext, self).pre_exec()
 
@@ -253,7 +268,7 @@ class MSSQLExecutionContext(default.DefaultExecutionContext):
         if self.compiled.isinsert:
             if self.IINSERT:
                 # TODO: quoting rules for table name here ?
-                self.cursor.execute("SET IDENTITY_INSERT %s OFF" % self.compiled.statement.table.name)
+                self.cursor.execute("SET IDENTITY_INSERT %s OFF" % self.compiled.statement.table.fullname)
                 self.IINSERT = False
             elif self.HASIDENT:
                 if self.dialect.use_scope_identity:
@@ -294,6 +309,7 @@ class MSSQLDialect(ansisql.ANSIDialect):
         sqltypes.TEXT : MSText,
         sqltypes.CHAR: MSChar,
         sqltypes.NCHAR: MSNChar,
+        sqltypes.TIMESTAMP: MSTimeStamp,
     }
 
     ischema_names = {
@@ -314,7 +330,8 @@ class MSSQLDialect(ansisql.ANSIDialect):
         'binary' : MSBinary,
         'bit': MSBoolean,
         'real' : MSFloat,
-        'image' : MSBinary
+        'image' : MSBinary,
+        'timestamp': MSTimeStamp,
     }
 
     def __new__(cls, dbapi=None, *args, **kwargs):
@@ -330,7 +347,7 @@ class MSSQLDialect(ansisql.ANSIDialect):
         super(MSSQLDialect, self).__init__(**params)
         self.auto_identity_insert = auto_identity_insert
         self.text_as_varchar = False
-        self.use_scope_identity = True
+        self.use_scope_identity = False
         self.set_default_schema_name("dbo")
 
     def dbapi(cls, module_name=None):
@@ -570,6 +587,16 @@ class MSSQLDialect_pymssql(MSSQLDialect):
         return module
     import_dbapi = classmethod(import_dbapi)
     
+    colspecs = MSSQLDialect.colspecs.copy()
+    colspecs[sqltypes.Date] = MSDate_pymssql
+
+    ischema_names = MSSQLDialect.ischema_names.copy()
+    ischema_names['smalldatetime'] = MSDate_pymssql
+
+    def __init__(self, **params):
+        super(MSSQLDialect_pymssql, self).__init__(**params)
+        self.use_scope_identity = True
+
     def supports_sane_rowcount(self):
         return True
 
@@ -641,12 +668,21 @@ class MSSQLDialect_pyodbc(MSSQLDialect):
     
     colspecs = MSSQLDialect.colspecs.copy()
     colspecs[sqltypes.Unicode] = AdoMSNVarchar
+    colspecs[sqltypes.Date] = MSDate_pyodbc
+    colspecs[sqltypes.DateTime] = MSDateTime_pyodbc
+
     ischema_names = MSSQLDialect.ischema_names.copy()
     ischema_names['nvarchar'] = AdoMSNVarchar
+    ischema_names['smalldatetime'] = MSDate_pyodbc
+    ischema_names['datetime'] = MSDateTime_pyodbc
 
     def supports_sane_rowcount(self):
         return False
 
+    def supports_unicode_statements(self):
+        """indicate whether the DBAPI can receive SQL statements as Python unicode strings"""
+        return True
+
     def make_connect_string(self, keys):
         connectors = ["Driver={SQL Server}"]
         connectors.append("Server=%s" % keys.get("host"))
@@ -674,12 +710,19 @@ class MSSQLDialect_adodbapi(MSSQLDialect):
 
     colspecs = MSSQLDialect.colspecs.copy()
     colspecs[sqltypes.Unicode] = AdoMSNVarchar
+    colspecs[sqltypes.DateTime] = MSDateTime_adodbapi
+
     ischema_names = MSSQLDialect.ischema_names.copy()
     ischema_names['nvarchar'] = AdoMSNVarchar
+    ischema_names['datetime'] = MSDateTime_adodbapi
 
     def supports_sane_rowcount(self):
         return True
 
+    def supports_unicode_statements(self):
+        """indicate whether the DBAPI can receive SQL statements as Python unicode strings"""
+        return True
+
     def make_connect_string(self, keys):
         connectors = ["Provider=SQLOLEDB"]
         if 'port' in keys:
index 41b61d4afb65985487d0d931ab4d2be616516d0e..489e9d59e2fa84ffa655ee909307f7f3f02a6b60 100644 (file)
@@ -2686,6 +2686,7 @@ class CompoundSelect(_SelectBaseMixin, FromClause):
         self.is_compound = True
         self.is_where = False
         self.is_scalar = False
+        self.is_subquery = False
 
         self.selects = selects
 
index d788544c096646a6404ef0f1cdcbd39a978633f8..6e43f8779a7295713404cf1957f1fc3e94c9bc29 100644 (file)
@@ -402,6 +402,19 @@ class QueryTest(PersistTest):
             con.execute("""drop trigger paj""")
             meta.drop_all()
 
+    @testbase.supported('mssql')
+    def test_insertid_schema(self):
+        meta = BoundMetaData(testbase.db)
+        con = testbase.db.connect()
+        con.execute('create schema paj')
+        tbl = Table('test', meta, Column('id', Integer, primary_key=True), schema='paj')
+        tbl.create()        
+        try:
+            tbl.insert().execute({'id':1})        
+        finally:
+            tbl.drop()
+            con.execute('drop schema paj')
+        
 
 class CompoundTest(PersistTest):
     """test compound statements like UNION, INTERSECT, particularly their ability to nest on
index 05d0f21105a76f7512c1e9d8c7f94ee23d823ba6..95cab898c3cd223deef184ebe2b2ff2132774a17 100644 (file)
@@ -31,7 +31,7 @@ class FoundRowsTest(testbase.AssertMixin):
         i.execute(*[{'name':n, 'department':d} for n, d in data])
     def tearDown(self):
         employees_table.delete().execute()
-        
+
     def tearDownAll(self):
         employees_table.drop()
 
@@ -45,23 +45,26 @@ class FoundRowsTest(testbase.AssertMixin):
         # WHERE matches 3, 3 rows changed
         department = employees_table.c.department
         r = employees_table.update(department=='C').execute(department='Z')
-        assert r.rowcount == 3
-        
+        if testbase.db.dialect.supports_sane_rowcount():
+            assert r.rowcount == 3
+
     def test_update_rowcount2(self):
         # WHERE matches 3, 0 rows changed
         department = employees_table.c.department
         r = employees_table.update(department=='C').execute(department='C')
-        assert r.rowcount == 3
-        
+        if testbase.db.dialect.supports_sane_rowcount():
+            assert r.rowcount == 3
+
     def test_delete_rowcount(self):
         # WHERE matches 3, 3 rows deleted
         department = employees_table.c.department
         r = employees_table.delete(department=='C').execute()
-        assert r.rowcount == 3
+        if testbase.db.dialect.supports_sane_rowcount():
+            assert r.rowcount == 3
 
 if __name__ == '__main__':
     testbase.main()
-    
+
 
 
 
index acf21b917ea6e07ddd7fa265aebc45e6c6db73cd..b2b747a334afabbb2c9ce0ddb679a9674ef05378 100644 (file)
@@ -190,6 +190,11 @@ class UnicodeTest(AssertMixin):
         finally:
             db.engine.dialect.convert_unicode = prev_unicode
 
+    def testlength(self):
+        """checks the database correctly understands the length of a unicode string"""
+        teststr = u'aaa\x1234'
+        self.assert_(db.func.length(teststr).scalar() == len(teststr))
+  
 class BinaryTest(AssertMixin):
     def setUpAll(self):
         global binary_table
@@ -313,6 +318,24 @@ class DateTest(AssertMixin):
         #x = db.text("select * from query_users_with_date where user_datetime=:date", bindparams=[bindparam('date', )]).execute(date=datetime.datetime(2005, 11, 10, 11, 52, 35)).fetchall()
         #print repr(x)
 
+    @testbase.unsupported('sqlite')
+    def testdate2(self):
+        t = Table('testdate', testbase.metadata, Column('id', Integer, primary_key=True),
+                Column('adate', Date), Column('adatetime', DateTime))
+        t.create()
+        try:
+            d1 = datetime.date(2007, 10, 30)
+            t.insert().execute(adate=d1, adatetime=d1)
+            d2 = datetime.datetime(2007, 10, 30)
+            t.insert().execute(adate=d2, adatetime=d2)
+
+            x = t.select().execute().fetchall()[0]
+            self.assert_(x.adate.__class__ == datetime.date)
+            self.assert_(x.adatetime.__class__ == datetime.datetime)
+
+        finally:
+            t.drop()
+
 class TimezoneTest(AssertMixin):
     """test timezone-aware datetimes.  psycopg will return a datetime with a tzinfo attached to it,
     if postgres returns it.  python then will not let you compare a datetime with a tzinfo to a datetime