From: Paul Johnston Date: Wed, 13 Jun 2007 18:53:16 +0000 (+0000) Subject: Multiple MSSQL fixes; see ticket #581 X-Git-Tag: rel_0_3_9~83 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=2c65ce75360c6018ecc6160d6ef11e23ae628553;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Multiple MSSQL fixes; see ticket #581 --- diff --git a/lib/sqlalchemy/databases/mssql.py b/lib/sqlalchemy/databases/mssql.py index 1cadbd14d0..4336296dd9 100644 --- a/lib/sqlalchemy/databases/mssql.py +++ b/lib/sqlalchemy/databases/mssql.py @@ -99,14 +99,14 @@ class MSDateTime(sqltypes.DateTime): def get_col_spec(self): return "DATETIME" - def convert_bind_param(self, value, dialect): - if hasattr(value, "isoformat"): - #return value.isoformat(' ') - # isoformat() bings on apodbapi -- reported/suggested by Peter Buschman - return value.strftime('%Y-%m-%d %H:%M:%S') - else: - return value +class MSDate(sqltypes.Date): + def __init__(self, *a, **kw): + super(MSDate, self).__init__(False) + def get_col_spec(self): + return "SMALLDATETIME" + +class MSDateTime_adodbapi(MSDateTime): def convert_result_value(self, value, dialect): # adodbapi will return datetimes with empty time values as datetime.date() objects. # Promote them back to full datetime.datetime() @@ -114,23 +114,34 @@ class MSDateTime(sqltypes.DateTime): return datetime.datetime(value.year, value.month, value.day) return value -class MSDate(sqltypes.Date): - def __init__(self, *a, **kw): - super(MSDate, self).__init__(False) +class MSDateTime_pyodbc(MSDateTime): + def convert_bind_param(self, value, dialect): + if value and not hasattr(value, 'second'): + return datetime.datetime(value.year, value.month, value.day) + else: + return value - def get_col_spec(self): - return "SMALLDATETIME" - +class MSDate_pyodbc(MSDate): def convert_bind_param(self, value, dialect): - if value and hasattr(value, "isoformat"): - return value.strftime('%Y-%m-%d %H:%M') - return value + if value and not hasattr(value, 'second'): + return datetime.datetime(value.year, value.month, value.day) + else: + return value + def convert_result_value(self, value, dialect): + # pyodbc returns SMALLDATETIME values as datetime.datetime(). truncate it back to datetime.date() + if value and hasattr(value, 'second'): + return value.date() + else: + return value + +class MSDate_pymssql(MSDate): def convert_result_value(self, value, dialect): # pymssql will return SMALLDATETIME values as datetime.datetime(), truncate it back to datetime.date() if value and hasattr(value, 'second'): return value.date() - return value + else: + return value class MSText(sqltypes.TEXT): def get_col_spec(self): @@ -143,7 +154,7 @@ class MSString(sqltypes.String): def get_col_spec(self): return "VARCHAR(%(length)s)" % {'length' : self.length} -class MSNVarchar(MSString): +class MSNVarchar(sqltypes.Unicode): def get_col_spec(self): if self.length: return "NVARCHAR(%(length)s)" % {'length' : self.length} @@ -191,6 +202,10 @@ class MSBoolean(sqltypes.Boolean): else: return value and True or False +class MSTimeStamp(sqltypes.TIMESTAMP): + def get_col_spec(self): + return "TIMESTAMP" + def descriptor(): return {'name':'mssql', 'description':'MSSQL', @@ -240,7 +255,7 @@ class MSSQLExecutionContext(default.DefaultExecutionContext): if self.IINSERT: # TODO: quoting rules for table name here ? - self.cursor.execute("SET IDENTITY_INSERT %s ON" % self.compiled.statement.table.name) + self.cursor.execute("SET IDENTITY_INSERT %s ON" % self.compiled.statement.table.fullname) super(MSSQLExecutionContext, self).pre_exec() @@ -253,7 +268,7 @@ class MSSQLExecutionContext(default.DefaultExecutionContext): if self.compiled.isinsert: if self.IINSERT: # TODO: quoting rules for table name here ? - self.cursor.execute("SET IDENTITY_INSERT %s OFF" % self.compiled.statement.table.name) + self.cursor.execute("SET IDENTITY_INSERT %s OFF" % self.compiled.statement.table.fullname) self.IINSERT = False elif self.HASIDENT: if self.dialect.use_scope_identity: @@ -294,6 +309,7 @@ class MSSQLDialect(ansisql.ANSIDialect): sqltypes.TEXT : MSText, sqltypes.CHAR: MSChar, sqltypes.NCHAR: MSNChar, + sqltypes.TIMESTAMP: MSTimeStamp, } ischema_names = { @@ -314,7 +330,8 @@ class MSSQLDialect(ansisql.ANSIDialect): 'binary' : MSBinary, 'bit': MSBoolean, 'real' : MSFloat, - 'image' : MSBinary + 'image' : MSBinary, + 'timestamp': MSTimeStamp, } def __new__(cls, dbapi=None, *args, **kwargs): @@ -330,7 +347,7 @@ class MSSQLDialect(ansisql.ANSIDialect): super(MSSQLDialect, self).__init__(**params) self.auto_identity_insert = auto_identity_insert self.text_as_varchar = False - self.use_scope_identity = True + self.use_scope_identity = False self.set_default_schema_name("dbo") def dbapi(cls, module_name=None): @@ -570,6 +587,16 @@ class MSSQLDialect_pymssql(MSSQLDialect): return module import_dbapi = classmethod(import_dbapi) + colspecs = MSSQLDialect.colspecs.copy() + colspecs[sqltypes.Date] = MSDate_pymssql + + ischema_names = MSSQLDialect.ischema_names.copy() + ischema_names['smalldatetime'] = MSDate_pymssql + + def __init__(self, **params): + super(MSSQLDialect_pymssql, self).__init__(**params) + self.use_scope_identity = True + def supports_sane_rowcount(self): return True @@ -641,12 +668,21 @@ class MSSQLDialect_pyodbc(MSSQLDialect): colspecs = MSSQLDialect.colspecs.copy() colspecs[sqltypes.Unicode] = AdoMSNVarchar + colspecs[sqltypes.Date] = MSDate_pyodbc + colspecs[sqltypes.DateTime] = MSDateTime_pyodbc + ischema_names = MSSQLDialect.ischema_names.copy() ischema_names['nvarchar'] = AdoMSNVarchar + ischema_names['smalldatetime'] = MSDate_pyodbc + ischema_names['datetime'] = MSDateTime_pyodbc def supports_sane_rowcount(self): return False + def supports_unicode_statements(self): + """indicate whether the DBAPI can receive SQL statements as Python unicode strings""" + return True + def make_connect_string(self, keys): connectors = ["Driver={SQL Server}"] connectors.append("Server=%s" % keys.get("host")) @@ -674,12 +710,19 @@ class MSSQLDialect_adodbapi(MSSQLDialect): colspecs = MSSQLDialect.colspecs.copy() colspecs[sqltypes.Unicode] = AdoMSNVarchar + colspecs[sqltypes.DateTime] = MSDateTime_adodbapi + ischema_names = MSSQLDialect.ischema_names.copy() ischema_names['nvarchar'] = AdoMSNVarchar + ischema_names['datetime'] = MSDateTime_adodbapi def supports_sane_rowcount(self): return True + def supports_unicode_statements(self): + """indicate whether the DBAPI can receive SQL statements as Python unicode strings""" + return True + def make_connect_string(self, keys): connectors = ["Provider=SQLOLEDB"] if 'port' in keys: diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py index 41b61d4afb..489e9d59e2 100644 --- a/lib/sqlalchemy/sql.py +++ b/lib/sqlalchemy/sql.py @@ -2686,6 +2686,7 @@ class CompoundSelect(_SelectBaseMixin, FromClause): self.is_compound = True self.is_where = False self.is_scalar = False + self.is_subquery = False self.selects = selects diff --git a/test/sql/query.py b/test/sql/query.py index d788544c09..6e43f8779a 100644 --- a/test/sql/query.py +++ b/test/sql/query.py @@ -402,6 +402,19 @@ class QueryTest(PersistTest): con.execute("""drop trigger paj""") meta.drop_all() + @testbase.supported('mssql') + def test_insertid_schema(self): + meta = BoundMetaData(testbase.db) + con = testbase.db.connect() + con.execute('create schema paj') + tbl = Table('test', meta, Column('id', Integer, primary_key=True), schema='paj') + tbl.create() + try: + tbl.insert().execute({'id':1}) + finally: + tbl.drop() + con.execute('drop schema paj') + class CompoundTest(PersistTest): """test compound statements like UNION, INTERSECT, particularly their ability to nest on diff --git a/test/sql/rowcount.py b/test/sql/rowcount.py index 05d0f21105..95cab898c3 100644 --- a/test/sql/rowcount.py +++ b/test/sql/rowcount.py @@ -31,7 +31,7 @@ class FoundRowsTest(testbase.AssertMixin): i.execute(*[{'name':n, 'department':d} for n, d in data]) def tearDown(self): employees_table.delete().execute() - + def tearDownAll(self): employees_table.drop() @@ -45,23 +45,26 @@ class FoundRowsTest(testbase.AssertMixin): # WHERE matches 3, 3 rows changed department = employees_table.c.department r = employees_table.update(department=='C').execute(department='Z') - assert r.rowcount == 3 - + if testbase.db.dialect.supports_sane_rowcount(): + assert r.rowcount == 3 + def test_update_rowcount2(self): # WHERE matches 3, 0 rows changed department = employees_table.c.department r = employees_table.update(department=='C').execute(department='C') - assert r.rowcount == 3 - + if testbase.db.dialect.supports_sane_rowcount(): + assert r.rowcount == 3 + def test_delete_rowcount(self): # WHERE matches 3, 3 rows deleted department = employees_table.c.department r = employees_table.delete(department=='C').execute() - assert r.rowcount == 3 + if testbase.db.dialect.supports_sane_rowcount(): + assert r.rowcount == 3 if __name__ == '__main__': testbase.main() - + diff --git a/test/sql/testtypes.py b/test/sql/testtypes.py index acf21b917e..b2b747a334 100644 --- a/test/sql/testtypes.py +++ b/test/sql/testtypes.py @@ -190,6 +190,11 @@ class UnicodeTest(AssertMixin): finally: db.engine.dialect.convert_unicode = prev_unicode + def testlength(self): + """checks the database correctly understands the length of a unicode string""" + teststr = u'aaa\x1234' + self.assert_(db.func.length(teststr).scalar() == len(teststr)) + class BinaryTest(AssertMixin): def setUpAll(self): global binary_table @@ -313,6 +318,24 @@ class DateTest(AssertMixin): #x = db.text("select * from query_users_with_date where user_datetime=:date", bindparams=[bindparam('date', )]).execute(date=datetime.datetime(2005, 11, 10, 11, 52, 35)).fetchall() #print repr(x) + @testbase.unsupported('sqlite') + def testdate2(self): + t = Table('testdate', testbase.metadata, Column('id', Integer, primary_key=True), + Column('adate', Date), Column('adatetime', DateTime)) + t.create() + try: + d1 = datetime.date(2007, 10, 30) + t.insert().execute(adate=d1, adatetime=d1) + d2 = datetime.datetime(2007, 10, 30) + t.insert().execute(adate=d2, adatetime=d2) + + x = t.select().execute().fetchall()[0] + self.assert_(x.adate.__class__ == datetime.date) + self.assert_(x.adatetime.__class__ == datetime.datetime) + + finally: + t.drop() + class TimezoneTest(AssertMixin): """test timezone-aware datetimes. psycopg will return a datetime with a tzinfo attached to it, if postgres returns it. python then will not let you compare a datetime with a tzinfo to a datetime