]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Rearranged engine initialization, its now easy to make ad-hoc testing engines that...
authorJason Kirtland <jek@discorporate.us>
Fri, 3 Aug 2007 20:08:26 +0000 (20:08 +0000)
committerJason Kirtland <jek@discorporate.us>
Fri, 3 Aug 2007 20:08:26 +0000 (20:08 +0000)
Promoted the 'utf8 bind' logic for tests needing utf8 connections into testlib
Added a pause before issuing DROPs to rid the testing db of clutter

test/engine/bind.py
test/sql/unicode.py
test/testlib/__init__.py
test/testlib/config.py
test/testlib/engines.py [new file with mode: 0644]

index 6a0c78f5782edb4a46c823d827dc109a8814801b..2d96683b9104d0eef8b2d6a1d5a4ea278cb0bfbf 100644 (file)
@@ -5,6 +5,7 @@ import testbase
 from sqlalchemy import *
 from testlib import *
 
+
 class BindTest(PersistTest):
     def test_create_drop_explicit(self):
         metadata = MetaData()
index 0ffa7c34d60e7f942c6942283dfc33f1a617540a..eebfe970fbe7e0291480529a760b48a5b26a767a 100644 (file)
@@ -5,13 +5,14 @@ import testbase
 from sqlalchemy import *
 from sqlalchemy.orm import mapper, relation, create_session, eagerload
 from testlib import *
+from testlib.engines import utf8_engine
 
 
 class UnicodeSchemaTest(PersistTest):
     def setUpAll(self):
         global unicode_bind, metadata, t1, t2
 
-        unicode_bind = self._unicode_bind()
+        unicode_bind = utf8_engine()
 
         metadata = MetaData(unicode_bind)
         t1 = Table('unitable1', metadata,
@@ -34,16 +35,6 @@ class UnicodeSchemaTest(PersistTest):
         global unicode_bind
         metadata.drop_all()
         del unicode_bind
-
-    def _unicode_bind(self):
-        if testbase.db.name != 'mysql':
-            return testbase.db
-        else:
-            from sqlalchemy.databases import mysql
-            engine = create_engine(testbase.db.url,
-                                   connect_args={'charset': 'utf8',
-                                                 'use_unicode': False})
-            return engine
         
     def test_insert(self):
         t1.insert().execute({u'méil':1, u'\u6e2c\u8a66':5})
index ff5c4c125e28e46d9bdd7462772f0ab9ae2d5625..3967863b8444e52b9aecf2278ffd201ffa44c5e4 100644 (file)
@@ -8,4 +8,9 @@ from testlib.schema import Table, Column
 import testlib.testing as testing
 from testlib.testing import PersistTest, AssertMixin, ORMTest
 import testlib.profiling
+import testlib.engines
 
+
+__all__ = ('testing',
+           'Table', 'Column',
+           'PersistTest', 'AssertMixin', 'ORMTest')
index db58718ceada19153aeaf9309e982ca9251f0be2..23f86a70ed4ca3043e17e59cc268806444b834a2 100644 (file)
@@ -1,10 +1,12 @@
 import testbase
-import optparse, os, sys, ConfigParser, StringIO
+import optparse, os, sys, ConfigParser, StringIO, time
 logging, require = None, None
 
+
 __all__ = 'parser', 'configure', 'options',
 
-db, db_uri, db_type, db_label = None, None, None, None
+db = None
+db_label, db_url, db_opts = None, None, {}
 
 options = None
 file_config = None
@@ -78,6 +80,13 @@ def _list_dbs(*args):
         print "%20s\t%s" % (macro, file_config.get('db', macro))
     sys.exit(0)
 
+def _server_side_cursors(options, opt_str, value, parser):
+    db_opts['server_side_cursors'] = True
+
+def _engine_strategy(options, opt_str, value, parser):
+    if value:
+        db_opts['strategy'] = value
+
 opt = parser.add_option
 opt("--verbose", action="store_true", dest="verbose",
     help="enable stdout echoing/printing")
@@ -96,12 +105,13 @@ opt("--dburi", action="store", dest="dburi",
     help="Database uri (overrides --db)")
 opt("--mockpool", action="store_true", dest="mockpool",
     help="Use mock pool (asserts only one connection used)")
-opt("--enginestrategy", action="store", dest="enginestrategy", default=None,
+opt("--enginestrategy", action="callback", type="string",
+    callback=_engine_strategy,
     help="Engine strategy (plain or threadlocal, defaults toplain)")
 opt("--reversetop", action="store_true", dest="reversetop", default=False,
     help="Reverse the collection ordering for topological sorts (helps "
           "reveal dependency issues)")
-opt("--serverside", action="store_true", dest="serverside",
+opt("--serverside", action="callback", callback=_server_side_cursors,
     help="Turn on server side cursors for PG")
 opt("--mysql-engine", action="store", dest="mysql_engine", default=None,
     help="Use the specified MySQL storage engine for all tables, default is "
@@ -132,23 +142,25 @@ class _ordered_map(object):
         for key in self._keys:
             yield self._data[key]
     
+# at one point in refactoring, modules were injecting into the config
+# process.  this could probably just become a list now.
 post_configure = _ordered_map()
 
 def _engine_uri(options, file_config):
-    global db_label, db_uri
+    global db_label, db_url
     db_label = 'sqlite'
     if options.dburi:
-        db_uri = options.dburi
-        db_label = db_uri[:db_uri.index(':')]
+        db_url = options.dburi
+        db_label = db_url[:db_url.index(':')]
     elif options.db:
         db_label = options.db
-        db_uri = None
+        db_url = None
 
-    if db_uri is None:
+    if db_url is None:
         if db_label not in file_config.options('db'):
             raise RuntimeError(
                 "Unknown engine.  Specify --dbs for known engines.")
-        db_uri = file_config.get('db', db_label)
+        db_url = file_config.get('db', db_label)
 post_configure['engine_uri'] = _engine_uri
 
 def _require(options, file_config):
@@ -177,37 +189,36 @@ def _require(options, file_config):
             pkg_resources.require(requirement)
 post_configure['require'] = _require
 
-def _create_testing_engine(options, file_config):
-    from sqlalchemy import engine, schema
-    global db, db_type
-    engine_opts = {}
-    if options.serverside:
-        engine_opts['server_side_cursors'] = True
-    
-    if options.enginestrategy is not None:
-        engine_opts['strategy'] = options.enginestrategy    
-
+def _engine_pool(options, file_config):
     if options.mockpool:
-        db = engine.create_engine(db_uri, poolclass=pool.AssertionPool,
-                                  **engine_opts)
-    else:
-        db = engine.create_engine(db_uri, **engine_opts)
-    db_type = db.name
-
-    print "Dropping existing tables in database: " + db_uri
-    md = schema.MetaData(db, reflect=True)
-    md.drop_all()
-    
-    # decorate the dialect's create_execution_context() method
-    # to produce a wrapper
-    from testlib.testing import ExecutionContextWrapper
-
-    create_context = db.dialect.create_execution_context
-    def create_exec_context(*args, **kwargs):
-        return ExecutionContextWrapper(create_context(*args, **kwargs))
-    db.dialect.create_execution_context = create_exec_context
+        from sqlalchemy import pool
+        db_opts['poolclass'] = pool.AssertionPool
+post_configure['engine_pool'] = _engine_pool
+
+def _create_testing_engine(options, file_config):
+    from testlib import engines
+    global db
+    db = engines.testing_engine(db_url, db_opts)
 post_configure['create_engine'] = _create_testing_engine
 
+def _prep_testing_database(options, file_config):
+    from sqlalchemy import schema
+
+    # also create alt schemas etc. here?
+    existing = db.table_names()
+    if existing:
+        print "Dropping existing tables in database: " + db_url
+        try:
+            print "Tables: %s" % ', '.join(existing)
+        except:
+            pass
+        print "Abort within 5 seconds..."
+        time.sleep(5)
+
+        md = schema.MetaData(db, reflect=True)
+        md.drop_all()
+post_configure['prep_db'] = _prep_testing_database
+
 def _set_table_options(options, file_config):
     import testlib.schema
     
diff --git a/test/testlib/engines.py b/test/testlib/engines.py
new file mode 100644 (file)
index 0000000..da33749
--- /dev/null
@@ -0,0 +1,33 @@
+from testlib import config
+
+
+def testing_engine(url=None, options=None):
+    """Produce an engine configured by --options with optional overrides."""
+    
+    from sqlalchemy import create_engine
+    from testlib.testing import ExecutionContextWrapper
+
+    url = url or config.db_url
+    options = options or config.db_opts
+
+    engine = create_engine(url, **options)
+
+    create_context = engine.dialect.create_execution_context
+    def create_exec_context(*args, **kwargs):
+        return ExecutionContextWrapper(create_context(*args, **kwargs))
+    engine.dialect.create_execution_context = create_exec_context
+    return engine
+
+def utf8_engine(url=None, options=None):
+    """Hook for dialects or drivers that don't handle utf8 by default."""
+
+    from sqlalchemy.engine import url as engine_url
+
+    if config.db.name == 'mysql':
+        url = url or config.db_url
+        url = engine_url.make_url(url)
+        url.query['charset'] = 'utf8'
+        url.query['use_unicode'] = '0'
+        url = str(url)
+
+    return testing_engine(url, options)