From 30b20e3c563d10935b1c85f43d0b7b3054f81415 Mon Sep 17 00:00:00 2001 From: Rick Morrison Date: Thu, 15 Mar 2007 02:31:15 +0000 Subject: [PATCH] MSSQL now passes still more unit tests [ticket:481] Fix to null FLOAT fields in mssql-trusted.patch MSSQL: LIMIT with OFFSET now raises an error MSSQL: can now specify Windows authorization MSSQL: ignores seconds on DATE columns (DATE fix, part 1) --- CHANGES | 11 ++++++ lib/sqlalchemy/databases/mssql.py | 66 +++++++++++++++++++++---------- test/engine/pool.py | 2 +- test/engine/reflection.py | 1 + test/ext/selectresults.py | 15 +++++-- test/orm/generative.py | 15 +++++-- test/orm/mapper.py | 24 ++++++++--- test/sql/query.py | 8 ++++ 8 files changed, 109 insertions(+), 33 deletions(-) diff --git a/CHANGES b/CHANGES index 745922297b..9f72facd3a 100644 --- a/CHANGES +++ b/CHANGES @@ -48,6 +48,7 @@ - the "else_" parameter to the case statement now properly works when set to zero. + - oracle: - got binary working for any size input ! cx_oracle works fine, it was my fault as BINARY was being passed and not BLOB for @@ -141,6 +142,16 @@ - query() method is added by assignmapper. this helps with navigating to all the new generative methods on Query. + +- ms-sql: + - removed seconds input on DATE column types (probably + should remove the time altogether) + + - null values in float fields no longer raise errors + + - LIMIT with OFFSET now raises an error (MS-SQL has no OFFSET support) + + 0.3.5 - sql: diff --git a/lib/sqlalchemy/databases/mssql.py b/lib/sqlalchemy/databases/mssql.py index 8c3c71f6ed..60d52d1819 100644 --- a/lib/sqlalchemy/databases/mssql.py +++ b/lib/sqlalchemy/databases/mssql.py @@ -61,8 +61,16 @@ def use_adodbapi(): # ADODBAPI has a non-standard Connection method connect = dbmodule.Connection def make_connect_string(keys): - return [["Provider=SQLOLEDB;Data Source=%s;User Id=%s;Password=%s;Initial Catalog=%s" % ( - keys.get("host"), keys.get("user"), keys.get("password", ""), keys.get("database"))], {}] + connectors = ["Provider=SQLOLEDB"] + connectors.append ("Data Source=%s" % keys.get("host")) + connectors.append ("Initial Catalog=%s" % keys.get("database")) + user = keys.get("user") + if user: + connectors.append("User Id=%s" % user) + connectors.append("Password=%s" % keys.get("password", "")) + else: + connectors.append("Integrated Security=SSPI") + return [[";".join (connectors)], {}] sane_rowcount = True dialect = MSSQLDialect colspecs[sqltypes.Unicode] = AdoMSUnicode @@ -91,8 +99,16 @@ def use_pyodbc(): import pyodbc as dbmodule connect = dbmodule.connect def make_connect_string(keys): - return [["Driver={SQL Server};Server=%s;UID=%s;PWD=%s;Database=%s" % ( - keys.get("host"), keys.get("user"), keys.get("password", ""), keys.get("database"))], {}] + connectors = ["Driver={SQL Server}"] + connectors.append("Server=%s" % keys.get("host")) + connectors.append("Database=%s" % keys.get("database")) + user = keys.get("user") + if user: + connectors.append("UID=%s" % user) + connectors.append("PWD=%s" % keys.get("password", "")) + else: + connectors.append ("TrustedConnection=Yes") + return [[";".join (connectors)], {}] do_commit = True sane_rowcount = False dialect = MSSQLDialect @@ -150,7 +166,7 @@ class MSFloat(sqltypes.Float): def convert_bind_param(self, value, dialect): """By converting to string, we can use Decimal types round-trip.""" - return str(value) + return value and str(value) or None class MSInteger(sqltypes.Integer): def get_col_spec(self): @@ -195,7 +211,7 @@ class MSDate(sqltypes.Date): def convert_bind_param(self, value, dialect): if value and hasattr(value, "isoformat"): - return value.strftime('%Y-%m-%d %H:%M:%S') + return value.strftime('%Y-%m-%d %H:%M') return value def convert_result_value(self, value, dialect): @@ -327,26 +343,29 @@ class MSSQLExecutionContext(default.DefaultExecutionContext): def __init__(self, dialect): self.IINSERT = self.HASIDENT = False super(MSSQLExecutionContext, self).__init__(dialect) - + + def _has_implicit_sequence(self, column): + if column.primary_key and column.autoincrement: + if isinstance(column.type, sqltypes.Integer) and not column.foreign_key: + if column.default is None or (isinstance(column.default, schema.Sequence) and \ + column.default.optional): + return True + return False + def pre_exec(self, engine, proxy, compiled, parameters, **kwargs): """MS-SQL has a special mode for inserting non-NULL values into IDENTITY columns. Activate it if the feature is turned on and needed. """ - if getattr(compiled, "isinsert", False): tbl = compiled.statement.table - if not hasattr(tbl, 'has_sequence'): + if not hasattr(tbl, 'has_sequence'): + tbl.has_sequence = False for column in tbl.c: - if column.primary_key and column.autoincrement and \ - isinstance(column.type, sqltypes.Integer) and not column.foreign_key: - if column.default is None or (isinstance(column.default, schema.Sequence) and \ - column.default.optional): - tbl.has_sequence = column - break - else: - tbl.has_sequence = False + if getattr(column, 'sequence', False) or self._has_implicit_sequence(column): + tbl.has_sequence = column + break self.HASIDENT = bool(tbl.has_sequence) if engine.dialect.auto_identity_insert and self.HASIDENT: @@ -520,6 +539,10 @@ class MSSQLDialect(ansisql.ANSIDialect): row[columns.c.column_default] ) + # cope with varchar(max) + if charlen == -1: + charlen = None + args = [] for a in (charlen, numericprec, numericscale): if a is not None: @@ -644,12 +667,14 @@ class MSSQLCompiler(ansisql.ANSICompiler): def visit_select_precolumns(self, select): """ MS-SQL puts TOP, it's version of LIMIT here """ s = select.distinct and "DISTINCT " or "" - if (select.limit): + if select.limit: s += "TOP %s " % (select.limit,) + if select.offset: + raise exceptions.InvalidRequestError('MSSQL does not support LIMIT with an offset') return s - def limit_clause(self, select): - # Limit in mssql is after the select keyword; MSsql has no support for offset + def limit_clause(self, select): + # Limit in mssql is after the select keyword return "" def visit_table(self, table): @@ -744,3 +769,4 @@ class MSSQLIdentifierPreparer(ansisql.ANSIIdentifierPreparer): use_default() + diff --git a/test/engine/pool.py b/test/engine/pool.py index 08df106ce2..db97ea6f8d 100644 --- a/test/engine/pool.py +++ b/test/engine/pool.py @@ -162,7 +162,7 @@ class PoolTest(PersistTest): c2 = p.connect() assert id(c2.connection) == c_id c2.close() - time.sleep(3) + time.sleep(4) c3= p.connect() assert id(c3.connection) != c_id diff --git a/test/engine/reflection.py b/test/engine/reflection.py index 51a3d35c67..62cd92b6e6 100644 --- a/test/engine/reflection.py +++ b/test/engine/reflection.py @@ -267,6 +267,7 @@ class ReflectionTest(PersistTest): testbase.db.execute("drop table django_admin_log") testbase.db.execute("drop table django_content_type") + @testbase.unsupported('mssql') def testmultipk(self): """test that creating a table checks for a sequence before creating it""" meta = BoundMetaData(testbase.db) diff --git a/test/ext/selectresults.py b/test/ext/selectresults.py index 88476c9cc0..8df416be94 100644 --- a/test/ext/selectresults.py +++ b/test/ext/selectresults.py @@ -39,7 +39,8 @@ class SelectResultsTest(PersistTest): res = self.query.select_by(range=5) assert res.order_by([Foo.c.bar])[0].bar == 5 assert res.order_by([desc(Foo.c.bar)])[0].bar == 95 - + + @testbase.unsupported('mssql') def test_slice(self): assert self.res[1] == self.orig[1] assert list(self.res[10:20]) == self.orig[10:20] @@ -50,6 +51,11 @@ class SelectResultsTest(PersistTest): assert list(self.res[-5:]) == self.orig[-5:] assert self.res[10:20][5] == self.orig[10:20][5] + @testbase.supported('mssql') + def test_slice_mssql(self): + assert list(self.res[:10]) == self.orig[:10] + assert list(self.res[:10]) == self.orig[:10] + def test_aggregate(self): assert self.res.count() == 100 assert self.res.filter(foo.c.bar<30).min(foo.c.bar) == 0 @@ -60,11 +66,14 @@ class SelectResultsTest(PersistTest): # this one fails in mysql as the result comes back as a string assert self.res.filter(foo.c.bar<30).sum(foo.c.bar) == 435 - @testbase.unsupported('postgres', 'mysql', 'firebird') + @testbase.unsupported('postgres', 'mysql', 'firebird', 'mssql') def test_aggregate_2(self): - # this one fails with postgres, the floating point comparison fails assert self.res.filter(foo.c.bar<30).avg(foo.c.bar) == 14.5 + @testbase.supported('postgres', 'mysql', 'firebird', 'mssql') + def test_aggregate_2_int(self): + assert int(self.res.filter(foo.c.bar<30).avg(foo.c.bar)) == 14 + def test_filter(self): assert self.res.count() == 100 assert self.res.filter(Foo.c.bar < 30).count() == 30 diff --git a/test/orm/generative.py b/test/orm/generative.py index 37ce1dcc9b..b8c2a85e1b 100644 --- a/test/orm/generative.py +++ b/test/orm/generative.py @@ -39,6 +39,7 @@ class GenerativeQueryTest(PersistTest): assert res.order_by([Foo.c.bar])[0].bar == 5 assert res.order_by([desc(Foo.c.bar)])[0].bar == 95 + @testbase.unsupported('mssql') def test_slice(self): assert self.query[1] == self.orig[1] assert list(self.query[10:20]) == self.orig[10:20] @@ -49,6 +50,11 @@ class GenerativeQueryTest(PersistTest): assert list(self.query[-5:]) == self.orig[-5:] assert self.query[10:20][5] == self.orig[10:20][5] + @testbase.supported('mssql') + def test_slice_mssql(self): + assert list(self.query[:10]) == self.orig[:10] + assert list(self.query[:10]) == self.orig[:10] + def test_aggregate(self): assert self.query.count() == 100 assert self.query.filter(foo.c.bar<30).min(foo.c.bar) == 0 @@ -59,10 +65,13 @@ class GenerativeQueryTest(PersistTest): # this one fails in mysql as the result comes back as a string assert self.query.filter(foo.c.bar<30).sum(foo.c.bar) == 435 - @testbase.unsupported('postgres', 'mysql', 'firebird') + @testbase.unsupported('postgres', 'mysql', 'firebird', 'mssql') def test_aggregate_2(self): - # this one fails with postgres, the floating point comparison fails - assert self.query.filter(foo.c.bar<30).avg(foo.c.bar) == 14.5 + assert self.res.filter(foo.c.bar<30).avg(foo.c.bar) == 14.5 + + @testbase.supported('postgres', 'mysql', 'firebird', 'mssql') + def test_aggregate_2_int(self): + assert int(self.res.filter(foo.c.bar<30).avg(foo.c.bar)) == 14 def test_filter(self): assert self.query.count() == 100 diff --git a/test/orm/mapper.py b/test/orm/mapper.py index 52c0e37e6c..f5a4613c95 100644 --- a/test/orm/mapper.py +++ b/test/orm/mapper.py @@ -983,8 +983,14 @@ class LazyTest(MapperSuperTest): )) sess= create_session() q = sess.query(m) - l = q.select(limit=2, offset=1) - self.assert_result(l, User, *user_all_result[1:3]) + + if db.engine.name == 'mssql': + l = q.select(limit=2) + self.assert_result(l, User, *user_all_result[:2]) + else: + l = q.select(limit=2, offset=1) + self.assert_result(l, User, *user_all_result[1:3]) + # use a union all to get a lot of rows to join against u2 = users.alias('u2') s = union_all(u2.select(use_labels=True), u2.select(use_labels=True), u2.select(use_labels=True)).alias('u') @@ -1124,8 +1130,13 @@ class EagerTest(MapperSuperTest): sess = create_session() q = sess.query(m) - l = q.select(limit=2, offset=1) - self.assert_result(l, User, *user_all_result[1:3]) + if db.engine.name == 'mssql': + l = q.select(limit=2) + self.assert_result(l, User, *user_all_result[:2]) + else: + l = q.select(limit=2, offset=1) + self.assert_result(l, User, *user_all_result[1:3]) + # this is an involved 3x union of the users table to get a lot of rows. # then see if the "distinct" works its way out. you actually get the same # result with or without the distinct, just via less or more rows. @@ -1156,8 +1167,9 @@ class EagerTest(MapperSuperTest): sess = create_session() q = sess.query(m) - l = q.select(q.join_to('orders'), order_by=desc(orders.c.user_id), limit=2, offset=1) - self.assert_result(l, User, *(user_all_result[2], user_all_result[0])) + if db.engine.name != 'mssql': + l = q.select(q.join_to('orders'), order_by=desc(orders.c.user_id), limit=2, offset=1) + self.assert_result(l, User, *(user_all_result[2], user_all_result[0])) l = q.select(q.join_to('addresses'), order_by=desc(addresses.c.email_address), limit=1, offset=0) self.assert_result(l, User, *(user_all_result[0],)) diff --git a/test/sql/query.py b/test/sql/query.py index f9ba9409d6..6683fa0d04 100644 --- a/test/sql/query.py +++ b/test/sql/query.py @@ -174,6 +174,14 @@ class QueryTest(PersistTest): r = self.users.select(offset=5, order_by=[self.users.c.user_id]).execute().fetchall() self.assert_(r==[(6, 'ralph'), (7, 'fido')]) + @testbase.supported('mssql') + def testselectlimitoffset_mssql(self): + try: + r = self.users.select(limit=3, offset=2, order_by=[self.users.c.user_id]).execute().fetchall() + assert False # InvalidRequestError should have been raised + except exceptions.InvalidRequestError: + pass + @testbase.unsupported('mysql') def test_scalar_select(self): """test that scalar subqueries with labels get their type propigated to the result set.""" -- 2.47.2