]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
test fixes
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 29 Jan 2009 22:29:32 +0000 (22:29 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 29 Jan 2009 22:29:32 +0000 (22:29 +0000)
lib/sqlalchemy/engine/strategies.py
lib/sqlalchemy/orm/query.py
test/dialect/sqlite.py
test/engine/parseconnect.py
test/ext/compiler.py
test/sql/select.py

index b763997d42ce243f8e2fd821c9bfd4ace5824a28..75409c10f7e167001c65241b5b3ee7a15e899976 100644 (file)
@@ -109,6 +109,8 @@ class DefaultEngineStrategy(EngineStrategy):
             if k in kwargs:
                 engine_args[k] = kwargs.pop(k)
 
+        _initialize = kwargs.pop('_initialize', True)
+        
         # all kwargs should be consumed
         if kwargs:
             raise TypeError(
@@ -121,11 +123,13 @@ class DefaultEngineStrategy(EngineStrategy):
                                     engineclass.__name__))
                                     
         engine = engineclass(pool, dialect, u, **engine_args)
-        conn = engine.connect()
-        try:
-            dialect.initialize(conn)
-        finally:
-            conn.close()
+        
+        if _initialize:
+            conn = engine.connect()
+            try:
+                dialect.initialize(conn)
+            finally:
+                conn.close()
         return engine
 
     def pool_threadlocal(self):
index c5d0f72695f4cb3018ad00a1e2085329fad8ff70..e85ecaa58c7e52ddb62f9d727fbc1752e2dab4c3 100644 (file)
@@ -1552,7 +1552,7 @@ class Query(object):
         if synchronize_session == 'evaluate':
             try:
                 evaluator_compiler = evaluator.EvaluatorCompiler()
-                eval_condition = evaluator_compiler.process(self.whereclause)
+                eval_condition = evaluator_compiler.process(self.whereclause or expression._Null)
             except evaluator.UnevaluatableError:
                 raise sa_exc.InvalidRequestError("Could not evaluate current criteria in Python.  "
                         "Specify 'fetch' or False for the synchronize_session parameter.")
@@ -1645,7 +1645,7 @@ class Query(object):
         if synchronize_session == 'evaluate':
             try:
                 evaluator_compiler = evaluator.EvaluatorCompiler()
-                eval_condition = evaluator_compiler.process(self.whereclause)
+                eval_condition = evaluator_compiler.process(self.whereclause or expression._Null)
 
                 value_evaluators = {}
                 for key,value in values.items():
index 29beec8d355f1f6266f006ecebbd079dab03fb7c..3a33f58edeb2df471de0a4a420023c8d1d5ed139 100644 (file)
@@ -47,42 +47,36 @@ class TestTypes(TestBase, AssertsExecutionResults):
             bindproc = t.dialect_impl(dialect).bind_processor(dialect)
             assert not bindproc or isinstance(bindproc(u"some string"), unicode)
         
-    @testing.uses_deprecated('Using String type with no length')
     def test_type_reflection(self):
         # (ask_for, roundtripped_as_if_different)
-        specs = [( String(), pysqlite_dialect.SLString(), ),
-                 ( String(1), pysqlite_dialect.SLString(1), ),
-                 ( String(3), pysqlite_dialect.SLString(3), ),
-                 ( Text(), pysqlite_dialect.SLText(), ),
-                 ( Unicode(), pysqlite_dialect.SLString(), ),
-                 ( Unicode(1), pysqlite_dialect.SLString(1), ),
-                 ( Unicode(3), pysqlite_dialect.SLString(3), ),
-                 ( UnicodeText(), pysqlite_dialect.SLText(), ),
-                 ( CLOB, pysqlite_dialect.SLText(), ),
-                 ( pysqlite_dialect.SLChar(1), ),
-                 ( CHAR(3), pysqlite_dialect.SLChar(3), ),
-                 ( NCHAR(2), pysqlite_dialect.SLChar(2), ),
-                 ( NUMERIC, sqlite.SLNumeric(), ),
-                 ( NUMERIC(10,2), sqlite.SLNumeric(10,2), ),
-                 ( Numeric, sqlite.SLNumeric(), ),
-                 ( Numeric(10, 2), sqlite.SLNumeric(10, 2), ),
-                 ( DECIMAL, sqlite.SLNumeric(), ),
-                 ( DECIMAL(10, 2), sqlite.SLNumeric(10, 2), ),
-                 ( Float, sqlite.SLNumeric(), ),
-                 ( sqlite.SLNumeric(), ),
-                 ( TIMESTAMP, sqlite.SLDateTime(), ),
-                 ( DATETIME, sqlite.SLDateTime(), ),
-                 ( DateTime, sqlite.SLDateTime(), ),
-                 ( sqlite.SLDateTime(), ),
-                 ( DATE, sqlite.SLDate(), ),
-                 ( Date, sqlite.SLDate(), ),
-                 ( sqlite.SLDate(), ),
-                 ( TIME, sqlite.SLTime(), ),
-                 ( Time, sqlite.SLTime(), ),
-                 ( sqlite.SLTime(), ),
-                 ( BOOLEAN, sqlite.SLBoolean(), ),
-                 ( Boolean, sqlite.SLBoolean(), ),
-                 ( sqlite.SLBoolean(), ),
+        specs = [( String(), String(), ),
+                 ( String(1), String(1), ),
+                 ( String(3), String(3), ),
+                 ( Text(), Text(), ),
+                 ( Unicode(), String(), ),
+                 ( Unicode(1), String(1), ),
+                 ( Unicode(3), String(3), ),
+                 ( UnicodeText(), Text(), ),
+                 ( CHAR(1), ),
+                 ( CHAR(3), CHAR(3), ),
+                 ( NUMERIC, NUMERIC(), ),
+                 ( NUMERIC(10,2), NUMERIC(10,2), ),
+                 ( Numeric, NUMERIC(), ),
+                 ( Numeric(10, 2), NUMERIC(10, 2), ),
+                 ( DECIMAL, DECIMAL(), ),
+                 ( DECIMAL(10, 2), DECIMAL(10, 2), ),
+                 ( Float, Float(), ),
+                 ( NUMERIC(), ),
+                 ( TIMESTAMP, TIMESTAMP(), ),
+                 ( DATETIME, DATETIME(), ),
+                 ( DateTime, DateTime(), ),
+                 ( DateTime(), ),
+                 ( DATE, DATE(), ),
+                 ( Date, Date(), ),
+                 ( TIME, TIME(), ),
+                 ( Time, Time(), ),
+                 ( BOOLEAN, BOOLEAN(), ),
+                 ( Boolean, Boolean(), ),
                  ]
         columns = [Column('c%i' % (i + 1), t[0]) for i, t in enumerate(specs)]
 
@@ -101,7 +95,7 @@ class TestTypes(TestBase, AssertsExecutionResults):
                 expected = [len(c) > 1 and c[1] or c[0] for c in specs]
                 for table in rt, rv:
                     for i, reflected in enumerate(table.c):
-                        assert isinstance(reflected.type, type(expected[i])), type(expected[i])
+                        assert isinstance(reflected.type, type(expected[i])), "%d: %r" % (i, type(expected[i]))
             finally:
                 db.execute('DROP VIEW types_v')
         finally:
index c82ca6d58d38568a468523dd46232e21127465cd..4a6ca90d1444e249e7aa29dc0dc3869693ffa19b 100644 (file)
@@ -123,80 +123,43 @@ pool_timeout=10
 
     def test_recycle(self):
         dbapi = MockDBAPI(foober=12, lala=18, hoho={'this':'dict'}, fooz='somevalue')
-        e = create_engine('postgres://', pool_recycle=472, module=dbapi)
+        e = create_engine('postgres://', pool_recycle=472, module=dbapi, _initialize=False)
         assert e.pool._recycle == 472
 
     def test_badargs(self):
-        # good arg, use MockDBAPI to prevent oracle import errors
-        e = create_engine('oracle://', use_ansi=True, module=MockDBAPI())
-
-        try:
-            e = create_engine("foobar://", module=MockDBAPI())
-            assert False
-        except ImportError:
-            assert True
+        self.assertRaises(ImportError, create_engine, "foobar://", module=MockDBAPI())
 
         # bad arg
-        try:
-            e = create_engine('postgres://', use_ansi=True, module=MockDBAPI())
-            assert False
-        except TypeError:
-            assert True
+        self.assertRaises(TypeError, create_engine, 'postgres://', use_ansi=True, module=MockDBAPI())
 
         # bad arg
-        try:
-            e = create_engine('oracle://', lala=5, use_ansi=True, module=MockDBAPI())
-            assert False
-        except TypeError:
-            assert True
-
-        try:
-            e = create_engine('postgres://', lala=5, module=MockDBAPI())
-            assert False
-        except TypeError:
-            assert True
+        self.assertRaises(TypeError, create_engine, 'oracle://', lala=5, use_ansi=True, module=MockDBAPI())
 
-        try:
-            e = create_engine('sqlite://', lala=5)
-            assert False
-        except TypeError:
-            assert True
+        self.assertRaises(TypeError, create_engine, 'postgres://', lala=5, module=MockDBAPI())
 
-        try:
-            e = create_engine('mysql://', use_unicode=True, module=MockDBAPI())
-            assert False
-        except TypeError:
-            assert True
+        self.assertRaises(TypeError, create_engine,'sqlite://', lala=5)
 
-        try:
-            # sqlite uses SingletonThreadPool which doesnt have max_overflow
-            e = create_engine('sqlite://', max_overflow=5)
-            assert False
-        except TypeError:
-            assert True
+        self.assertRaises(TypeError, create_engine, 'mysql+mysqldb://', use_unicode=True, module=MockDBAPI())
 
-        e = create_engine('mysql://', module=MockDBAPI(), connect_args={'use_unicode':True}, convert_unicode=True)
+        # sqlite uses SingletonThreadPool which doesnt have max_overflow
+        self.assertRaises(TypeError, create_engine, 'sqlite://', max_overflow=5)
 
-        e = create_engine('sqlite://', connect_args={'use_unicode':True}, convert_unicode=True)
-        try:
-            c = e.connect()
-            assert False
-        except tsa.exc.DBAPIError:
-            assert True
+        # raises DBAPIerror due to use_unicode not a sqlite arg
+        self.assertRaises(tsa.exc.DBAPIError, create_engine, 'sqlite://', connect_args={'use_unicode':True}, convert_unicode=True)
 
     def test_urlattr(self):
         """test the url attribute on ``Engine``."""
 
-        e = create_engine('mysql://scott:tiger@localhost/test', module=MockDBAPI())
+        e = create_engine('mysql://scott:tiger@localhost/test', module=MockDBAPI(), _initialize=False)
         u = url.make_url('mysql://scott:tiger@localhost/test')
-        e2 = create_engine(u, module=MockDBAPI())
+        e2 = create_engine(u, module=MockDBAPI(), _initialize=False)
         assert e.url.drivername == e2.url.drivername == 'mysql'
         assert e.url.username == e2.url.username == 'scott'
         assert e2.url is u
 
     def test_poolargs(self):
         """test that connection pool args make it thru"""
-        e = create_engine('postgres://', creator=None, pool_recycle=50, echo_pool=None, module=MockDBAPI())
+        e = create_engine('postgres://', creator=None, pool_recycle=50, echo_pool=None, module=MockDBAPI(), _initialize=False)
         assert e.pool._recycle == 50
 
         # these args work for QueuePool
@@ -213,13 +176,14 @@ class MockDBAPI(object):
     def __init__(self, **kwargs):
         self.kwargs = kwargs
         self.paramstyle = 'named'
-    def connect(self, **kwargs):
-        print kwargs, self.kwargs
+    def connect(self, *args, **kwargs):
         for k in self.kwargs:
             assert k in kwargs, "key %s not present in dictionary" % k
             assert kwargs[k]==self.kwargs[k], "value %s does not match %s" % (kwargs[k], self.kwargs[k])
         return MockConnection()
 class MockConnection(object):
+    def get_server_info(self):
+        return "5.0"
     def close(self):
         pass
     def cursor(self):
index 79e1041c102cedfbfd2c0422cb2b144246cd0a5c..057ebe6b8e9938aacb9ced49d95c4ec5c3c3818a 100644 (file)
@@ -120,12 +120,12 @@ class UserDefinedTest(TestBase, AssertsCompiledSQL):
 
         self.assert_compile(AddThingy(),
             "ADD SPECIAL PG THINGY",
-            dialect=create_engine('postgres://').dialect
+            dialect=create_engine('postgres://', _initialize=False).dialect
         )
 
         self.assert_compile(DropThingy(),
             "DROP THINGY",
-            dialect=create_engine('postgres://').dialect
+            dialect=create_engine('postgres://', _initialize=False).dialect
         )
         
         
index 1790b3cdeace7422da6a75c71956f092361a8f67..858213b77b5478843d41a49199c41eab095f80a9 100644 (file)
@@ -500,10 +500,10 @@ sq.myothertable_othername AS sq_myothertable_othername FROM (" + sqstring + ") A
     def test_match(self):
         for expr, check, dialect in [
             (table1.c.myid.match('somstr'), "mytable.myid MATCH ?", sqlite.SQLiteDialect()),
-            (table1.c.myid.match('somstr'), "MATCH (mytable.myid) AGAINST (%s IN BOOLEAN MODE)", mysql.MySQLDialect()),
-            (table1.c.myid.match('somstr'), "CONTAINS (mytable.myid, :myid_1)", mssql.MSSQLDialect()),
-            (table1.c.myid.match('somstr'), "mytable.myid @@ to_tsquery(%(myid_1)s)", postgres.PGDialect()),
-            (table1.c.myid.match('somstr'), "CONTAINS (mytable.myid, :myid_1)", oracle.OracleDialect()),            
+            (table1.c.myid.match('somstr'), "MATCH (mytable.myid) AGAINST (%s IN BOOLEAN MODE)", mysql.dialect()),
+            (table1.c.myid.match('somstr'), "CONTAINS (mytable.myid, :myid_1)", mssql.dialect()),
+            (table1.c.myid.match('somstr'), "mytable.myid @@ to_tsquery(%(myid_1)s)", postgres.dialect()),
+            (table1.c.myid.match('somstr'), "CONTAINS (mytable.myid, :myid_1)", oracle.dialect()),            
         ]:
             self.assert_compile(expr, check, dialect=dialect)