From 0b90751a1ca1a16f9aab4fa6665127a4ccdef793 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Tue, 23 Jan 2007 21:14:54 +0000 Subject: [PATCH] test patches from [ticket:422] --- test/engine/transaction.py | 2 +- test/orm/session.py | 15 +++++++++++++++ test/tables.py | 8 ++++++++ test/testbase.py | 17 ++++++++++------- 4 files changed, 34 insertions(+), 8 deletions(-) diff --git a/test/engine/transaction.py b/test/engine/transaction.py index c89bf4b145..16667c1e05 100644 --- a/test/engine/transaction.py +++ b/test/engine/transaction.py @@ -144,7 +144,7 @@ class AutoRollbackTest(testbase.PersistTest): class TLTransactionTest(testbase.PersistTest): def setUpAll(self): global users, metadata, tlengine - tlengine = create_engine(testbase.db_uri, strategy='threadlocal') + tlengine = create_engine(testbase.db_uri, strategy='threadlocal', **testbase.db_opts) metadata = MetaData() users = Table('query_users', metadata, Column('user_id', INT, primary_key = True), diff --git a/test/orm/session.py b/test/orm/session.py index aab50ac475..148bcd2b5d 100644 --- a/test/orm/session.py +++ b/test/orm/session.py @@ -161,6 +161,21 @@ class SessionTest(AssertMixin): s.clear() assert s.query(Address).selectone().address_id == a.address_id assert s.query(User).selectfirst() is None + + def test_fetchid(self): + # this is necessary to ensure the test fails on old versions of mssql + if hasattr(autoseq.columns['autoseq_id'], 'sequence'): + del autoseq.columns['autoseq_id'].sequence + + mapper(Autoseq, autoseq) + s = create_session() + u = Autoseq() + s.save(u) + s.flush() + assert u.autoseq_id is not None + s.clear() + + class OrphanDeletionTest(AssertMixin): diff --git a/test/tables.py b/test/tables.py index f3b78a125c..f9c131653e 100644 --- a/test/tables.py +++ b/test/tables.py @@ -53,6 +53,11 @@ itemkeywords = Table('itemkeywords', metadata, # Column('foo', Boolean, default=True) ) +autoseq = Table('autoseq', metadata, + Column('autoseq_id', Integer, primary_key = True), + Column('name', String) +) + def create(): metadata.create_all() def drop(): @@ -165,6 +170,9 @@ class Keyword(object): def __repr__(self): return "Keyword: %s/%s" % (repr(getattr(self, 'keyword_id', None)),repr(self.name)) +class Autoseq(object): + def __init__(self): + self.autoseq_id = None user_result = [{'user_id' : 7}, {'user_id' : 8}, {'user_id' : 9}] diff --git a/test/testbase.py b/test/testbase.py index 360a02b320..e96b30dcbe 100644 --- a/test/testbase.py +++ b/test/testbase.py @@ -15,6 +15,7 @@ import optparse db = None metadata = None db_uri = None +db_opts = {} echo = True # redefine sys.stdout so all those print statements go to the echo func @@ -32,7 +33,7 @@ def echo_text(text): def parse_argv(): # we are using the unittest main runner, so we are just popping out the # arguments we need instead of using our own getopt type of thing - global db, db_uri, metadata + global db, db_uri, db_opts, metadata DBTYPE = 'sqlite' PROXY = False @@ -55,11 +56,13 @@ def parse_argv(): if options.dburi: db_uri = param = options.dburi + DBTYPE = db_uri[:db_uri.index(':')] elif options.db: DBTYPE = param = options.db - - opts = {} + if DBTYPE == 'mssql': + db_opts['auto_identity_insert'] = True + if (None == db_uri): if DBTYPE == 'sqlite': db_uri = 'sqlite:///:memory:' @@ -73,7 +76,7 @@ def parse_argv(): db_uri = 'oracle://scott:tiger@127.0.0.1:1521' elif DBTYPE == 'oracle8': db_uri = 'oracle://scott:tiger@127.0.0.1:1521' - opts = {'use_ansi':False, 'auto_setinputsizes':True} + db_opts = {'use_ansi':False, 'auto_setinputsizes':True} elif DBTYPE == 'mssql': db_uri = 'mssql://scott:tiger@SQUAWK\\SQLEXPRESS/test' elif DBTYPE == 'firebird': @@ -96,11 +99,11 @@ def parse_argv(): with_coverage = options.coverage if options.enginestrategy is not None: - opts['strategy'] = options.enginestrategy + db_opts['strategy'] = options.enginestrategy if options.mockpool: - db = engine.create_engine(db_uri, poolclass=pool.AssertionPool, **opts) + db = engine.create_engine(db_uri, poolclass=pool.AssertionPool, **db_opts) else: - db = engine.create_engine(db_uri, **opts) + db = engine.create_engine(db_uri, **db_opts) db = EngineAssert(db) import logging -- 2.47.2