- scale up for mysql, sqlite
author     Mike Bayer <mike_mp@zzzcomputing.com>
           Sun, 27 Jul 2014 00:50:57 +0000 (20:50 -0400)
committer  Mike Bayer <mike_mp@zzzcomputing.com>
           Sun, 27 Jul 2014 00:50:57 +0000 (20:50 -0400)
lib/sqlalchemy/testing/config.py
lib/sqlalchemy/testing/plugin/plugin_base.py
lib/sqlalchemy/testing/plugin/provision.py
lib/sqlalchemy/testing/plugin/pytestplugin.py
test/dialect/test_oracle.py
test/engine/test_reflection.py
test/sql/test_query.py

lib/sqlalchemy/testing/config.py
index 84344eb311151eeb4e98aa7f17f009a671dc4f1f..b24483bb72d77706a80e5fa19b9304e88b42fa6f 100644
@@ -12,7 +12,8 @@ db = None
 db_url = None
 db_opts = None
 file_config = None
-
+test_schema = None
+test_schema_2 = None
 _current = None
 
 
@@ -22,12 +23,14 @@ class Config(object):
         self.db_opts = db_opts
         self.options = options
         self.file_config = file_config
+        self.test_schema = "test_schema"
+        self.test_schema_2 = "test_schema_2"
 
     _stack = collections.deque()
     _configs = {}
 
     @classmethod
-    def register(cls, db, db_opts, options, file_config, namespace):
+    def register(cls, db, db_opts, options, file_config):
         """add a config as one of the global configs.
 
         If there are no configs set up yet, this config also
@@ -35,18 +38,18 @@ class Config(object):
         """
         cfg = Config(db, db_opts, options, file_config)
 
-        global _current
-        if not _current:
-            cls.set_as_current(cfg, namespace)
         cls._configs[cfg.db.name] = cfg
         cls._configs[(cfg.db.name, cfg.db.dialect)] = cfg
         cls._configs[cfg.db] = cfg
+        return cfg
 
     @classmethod
     def set_as_current(cls, config, namespace):
-        global db, _current, db_url
+        global db, _current, db_url, test_schema, test_schema_2
         _current = config
         db_url = config.db.url
+        test_schema = config.test_schema
+        test_schema_2 = config.test_schema_2
         namespace.db = db = config.db
 
     @classmethod
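
The config.py hunks make the schema names per-Config attributes with module-level mirrors, and register() now only records a config and returns it; promoting one to "current" is left to the caller. A minimal standalone sketch of that registry-plus-mirror pattern (simplified names, no SQLAlchemy dependency; illustration only, not the real module):

    # Simplified stand-in for testing.config: a registry of Configs plus
    # module-level mirrors that tests read directly.
    current = None
    test_schema = None

    class Config(object):
        _configs = {}

        def __init__(self, name, test_schema="test_schema"):
            self.name = name
            self.test_schema = test_schema

        @classmethod
        def register(cls, cfg):
            # Record the config only; the caller decides which one becomes
            # "current" (mirroring register() no longer calling set_as_current).
            cls._configs[cfg.name] = cfg
            return cfg

        @classmethod
        def set_as_current(cls, cfg):
            # Refresh the module-level mirrors.
            global current, test_schema
            current = cfg
            test_schema = cfg.test_schema

    for name in ("sqlite", "mysql", "postgresql"):
        cfg = Config.register(Config(name))
        if current is None:
            Config.set_as_current(cfg)

    print(test_schema)  # -> "test_schema" from the first (sqlite) config
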
lib/sqlalchemy/testing/plugin/plugin_base.py
index f16a0828f57f9b217458fe469ef60a6cd69ed74b..095e3f3697e8535f40cab5a613fbddf165f1b8d1 100644
@@ -103,7 +103,7 @@ def setup_options(make_option):
 
 def configure_follower(follower_ident):
     global FOLLOWER_IDENT
-    FOLLOWER_IDENT = "test_%s" % follower_ident
+    FOLLOWER_IDENT = follower_ident
 
 
 def read_config():
@@ -221,18 +221,20 @@ def _engine_uri(options, file_config):
     if not db_urls:
         db_urls.append(file_config.get('db', 'default'))
 
+    from . import provision
+
     for db_url in db_urls:
-        if FOLLOWER_IDENT:
-            from sqlalchemy.engine import url
-            db_url = url.make_url(db_url)
-            db_url.database = FOLLOWER_IDENT
-        eng = engines.testing_engine(db_url, db_opts)
-        eng.connect().close()
-        config.Config.register(eng, db_opts, options, file_config, testing)
+        cfg = provision.setup_config(
+            db_url, db_opts, options, file_config, FOLLOWER_IDENT)
+
+        if not config._current:
+            cfg.set_as_current(cfg, testing)
 
     config.db_opts = db_opts
 
 
+
+
 @post
 def _engine_pool(options, file_config):
     if options.mockpool:
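
With the "test_%s" prefix moved out of configure_follower(), the follower ident arrives fully formed and _engine_uri() delegates all URL rewriting to provision.setup_config(). A hedged sketch of the default rewrite that setup_config relies on (_follower_url_from_main in provision.py, shown below): same server, database renamed to the follower ident. Mutating url.database in place works on the 0.9-era URL object used here; newer SQLAlchemy URLs are immutable and would use url.set(database=...) instead.

    from sqlalchemy.engine import url as sa_url

    def follower_url_from_main(main_url, ident):
        # Default rule: keep driver/user/host, point at the follower's database.
        u = sa_url.make_url(main_url)
        u.database = ident          # mutable URL attribute on SQLAlchemy 0.9.x
        return u

    # e.g. postgresql://scott:tiger@localhost/test -> .../test_1
    print(follower_url_from_main("postgresql://scott:tiger@localhost/test", "test_1"))
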
lib/sqlalchemy/testing/plugin/provision.py
index e6790f877fd9c70f5f934e33568d9a1b509d39db..7c54cd643f1e64efc37553ce15a74b4b11a681f1 100644
@@ -1,11 +1,73 @@
 from sqlalchemy.engine import url as sa_url
+from sqlalchemy import text
+from sqlalchemy.util import compat
+from .. import config, engines
+import os
+
+
+class register(object):
+    def __init__(self):
+        self.fns = {}
+
+    @classmethod
+    def init(cls, fn):
+        return register().for_db("*")(fn)
+
+    def for_db(self, dbname):
+        def decorate(fn):
+            self.fns[dbname] = fn
+            return self
+        return decorate
+
+    def __call__(self, cfg, *arg):
+        if isinstance(cfg, compat.string_types):
+            url = sa_url.make_url(cfg)
+        elif isinstance(cfg, sa_url.URL):
+            url = cfg
+        else:
+            url = cfg.db.url
+        backend = url.get_backend_name()
+        if backend in self.fns:
+            return self.fns[backend](cfg, *arg)
+        else:
+            return self.fns['*'](cfg, *arg)
 
 
 def create_follower_db(follower_ident):
-    from .. import config, engines
 
-    follower_ident = "test_%s" % follower_ident
+    for cfg in _configs_for_db_operation():
+        url = cfg.db.url
+        backend = url.get_backend_name()
+        _create_db(cfg, cfg.db, follower_ident)
+
+        new_url = sa_url.make_url(str(url))
+
+        new_url.database = follower_ident
+
+
+def configure_follower(follower_ident):
+    for cfg in config.Config.all_configs():
+        _configure_follower(cfg, follower_ident)
+
+
+def setup_config(db_url, db_opts, options, file_config, follower_ident):
+    if follower_ident:
+        db_url = _follower_url_from_main(db_url, follower_ident)
+    eng = engines.testing_engine(db_url, db_opts)
+    eng.connect().close()
+    cfg = config.Config.register(eng, db_opts, options, file_config)
+    if follower_ident:
+        _configure_follower(cfg, follower_ident)
+    return cfg
+
+
+def drop_follower_db(follower_ident):
+    for cfg in _configs_for_db_operation():
+        url = cfg.db.url
+        _drop_db(cfg, cfg.db, follower_ident)
+
 
+def _configs_for_db_operation():
     hosts = set()
 
     for cfg in config.Config.all_configs():
@@ -19,47 +81,109 @@ def create_follower_db(follower_ident):
             url.username, url.host, url.database)
 
         if host_conf not in hosts:
-            if backend.startswith("postgresql"):
-                _pg_create_db(cfg.db, follower_ident)
-            elif backend.startswith("mysql"):
-                _mysql_create_db(cfg.db, follower_ident)
+            yield cfg
+            hosts.add(host_conf)
 
-            new_url = sa_url.make_url(str(url))
+    for cfg in config.Config.all_configs():
+        cfg.db.dispose()
 
-            new_url.database = follower_ident
-            eng = engines.testing_engine(new_url, cfg.db_opts)
 
-            if backend.startswith("postgresql"):
-                _pg_init_db(eng)
-            elif backend.startswith("mysql"):
-                _mysql_init_db(eng)
+@register.init
+def _create_db(cfg, eng, ident):
+    raise NotImplementedError("no DB creation routine for cfg: %s" % eng.url)
 
-            hosts.add(host_conf)
+
+@register.init
+def _drop_db(cfg, eng, ident):
+    raise NotImplementedError("no DB drop routine for cfg: %s" % eng.url)
+
+
+@register.init
+def _configure_follower(cfg, ident):
+    pass
+
+
+@register.init
+def _follower_url_from_main(url, ident):
+    url = sa_url.make_url(url)
+    url.database = ident
+    return url
+
+
+@_follower_url_from_main.for_db("sqlite")
+def _sqlite_follower_url_from_main(url, ident):
+    return sa_url.make_url("sqlite:///%s.db" % ident)
 
 
-def _pg_create_db(eng, ident):
+@_create_db.for_db("postgresql")
+def _pg_create_db(cfg, eng, ident):
     with eng.connect().execution_options(
             isolation_level="AUTOCOMMIT") as conn:
         try:
-            conn.execute("DROP DATABASE %s" % ident)
+            _pg_drop_db(cfg, conn, ident)
         except:
             pass
         currentdb = conn.scalar("select current_database()")
         conn.execute("CREATE DATABASE %s TEMPLATE %s" % (ident, currentdb))
 
 
-def _pg_init_db(eng):
+@_create_db.for_db("mysql")
+def _mysql_create_db(cfg, eng, ident):
+    with eng.connect() as conn:
+        try:
+            _mysql_drop_db(cfg, conn, ident)
+        except:
+            pass
+        conn.execute("CREATE DATABASE %s" % ident)
+        conn.execute("CREATE DATABASE %s_test_schema" % ident)
+        conn.execute("CREATE DATABASE %s_test_schema_2" % ident)
+
+
+@_configure_follower.for_db("mysql")
+def _mysql_configure_follower(config, ident):
+    config.test_schema = "%s_test_schema" % ident
+    config.test_schema_2 = "%s_test_schema_2" % ident
+
+
+@_create_db.for_db("sqlite")
+def _sqlite_create_db(cfg, eng, ident):
     pass
 
 
-def _mysql_create_db(eng, ident):
+@_drop_db.for_db("postgresql")
+def _pg_drop_db(cfg, eng, ident):
+    with eng.connect().execution_options(
+            isolation_level="AUTOCOMMIT") as conn:
+        conn.execute(
+            text(
+                "select pg_terminate_backend(pid) from pg_stat_activity "
+                "where usename=current_user and pid != pg_backend_pid() "
+                "and datname=:dname"
+            ), dname=ident)
+        conn.execute("DROP DATABASE %s" % ident)
+
+
+@_drop_db.for_db("sqlite")
+def _sqlite_drop_db(cfg, eng, ident):
+    os.remove("%s.db" % ident)
+
+
+@_drop_db.for_db("mysql")
+def _mysql_drop_db(cfg, eng, ident):
     with eng.connect() as conn:
+        try:
+            conn.execute("DROP DATABASE %s_test_schema" % ident)
+        except:
+            pass
+        try:
+            conn.execute("DROP DATABASE %s_test_schema_2" % ident)
+        except:
+            pass
         try:
             conn.execute("DROP DATABASE %s" % ident)
         except:
             pass
-        conn.execute("CREATE DATABASE %s" % ident)
 
 
-def _mysql_init_db(eng):
-    pass
+
+
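
The register object above is the core of the new provision module: @register.init installs a default implementation under the '*' key, and .for_db("mysql") and friends hang backend-specific overrides off the same callable, so _create_db(cfg, eng, ident) routes on the engine's backend name. A self-contained sketch of the same dispatch pattern, stripped of engines and URLs (string backend names only; illustration, not the real module):

    class register(object):
        # Per-backend dispatcher, modeled on the class added above.
        def __init__(self):
            self.fns = {}

        @classmethod
        def init(cls, fn):
            # Install fn as the default ('*') implementation and return the
            # dispatcher itself so later .for_db() decorators add overrides.
            return register().for_db("*")(fn)

        def for_db(self, dbname):
            def decorate(fn):
                self.fns[dbname] = fn
                return self
            return decorate

        def __call__(self, backend, *arg):
            fn = self.fns.get(backend, self.fns["*"])
            return fn(backend, *arg)


    @register.init
    def create_db(backend, ident):
        raise NotImplementedError("no create routine for %s" % backend)

    @create_db.for_db("sqlite")
    def _sqlite_create_db(backend, ident):
        return "would create %s.db" % ident

    print(create_db("sqlite", "test_1"))   # dispatches to the sqlite override
    # create_db("oracle", "test_1")        # falls back to the default and raises
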
lib/sqlalchemy/testing/plugin/pytestplugin.py
index 7bef644d98d25433d7eaacfc9dd5bc285d36a01b..7671c800c33316639055eef79c3ded85ee7dcddb 100644
@@ -5,6 +5,12 @@ from . import plugin_base
 import collections
 import itertools
 
+try:
+    import xdist
+    has_xdist = True
+except ImportError:
+    has_xdist = False
+
 
 def pytest_addoption(parser):
     group = parser.getgroup("sqlalchemy")
@@ -37,15 +43,19 @@ def pytest_configure(config):
 
     plugin_base.post_begin()
 
-_follower_count = itertools.count(1)
+if has_xdist:
+    _follower_count = itertools.count(1)
 
+    def pytest_configure_node(node):
+        # the master for each node fills slaveinput dictionary
+        # which pytest-xdist will transfer to the subprocess
+        node.slaveinput["follower_ident"] = "test_%s" % next(_follower_count)
+        from . import provision
+        provision.create_follower_db(node.slaveinput["follower_ident"])
 
-def pytest_configure_node(node):
-    # the master for each node fills slaveinput dictionary
-    # which pytest-xdist will transfer to the subprocess
-    node.slaveinput["follower_ident"] = next(_follower_count)
-    from . import provision
-    provision.create_follower_db(node.slaveinput["follower_ident"])
+    def pytest_testnodedown(node, error):
+        from . import provision
+        provision.drop_follower_db(node.slaveinput["follower_ident"])
 
 
 def pytest_collection_modifyitems(session, config, items):
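
pytest_configure_node and pytest_testnodedown are pytest-xdist master-side hooks, so they are now defined only when xdist is importable; the master picks a "test_N" ident, stashes it in node.slaveinput (which xdist copies into the worker process), and creates or drops the follower databases around the worker's lifetime. A small sketch of the master-side half using a fake node object (FakeNode is illustrative; the provisioning call is the real one, commented out):

    import itertools

    # Idents are "test_1", "test_2", ..., one per xdist worker.
    _follower_count = itertools.count(1)

    class FakeNode(object):
        # Minimal stand-in for an xdist node: only the slaveinput dict matters.
        def __init__(self):
            self.slaveinput = {}

    def configure_node(node):
        ident = "test_%s" % next(_follower_count)
        node.slaveinput["follower_ident"] = ident   # copied to the worker by xdist
        # provision.create_follower_db(ident)       # real call in the plugin
        return ident

    node = FakeNode()
    print(configure_node(node), node.slaveinput)
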
test/dialect/test_oracle.py
index f7c49c3d3df22bca1e215f8ca74281882bd32dcc..597a5dc407571a1562c98817005efeac1b753c73 100644
@@ -720,16 +720,16 @@ class MultiSchemaTest(fixtures.TestBase, AssertsCompiledSQL):
         # don't really know how else to go here unless
         # we connect as the other user.
 
-        for stmt in """
-create table test_schema.parent(
+        for stmt in ("""
+create table %(test_schema)s.parent(
     id integer primary key,
     data varchar2(50)
 );
 
-create table test_schema.child(
+create table %(test_schema)s.child(
     id integer primary key,
     data varchar2(50),
-    parent_id integer references test_schema.parent(id)
+    parent_id integer references %(test_schema)s.parent(id)
 );
 
 create table local_table(
@@ -737,35 +737,35 @@ create table local_table(
     data varchar2(50)
 );
 
-create synonym test_schema.ptable for test_schema.parent;
-create synonym test_schema.ctable for test_schema.child;
+create synonym %(test_schema)s.ptable for %(test_schema)s.parent;
+create synonym %(test_schema)s.ctable for %(test_schema)s.child;
 
-create synonym test_schema_ptable for test_schema.parent;
+create synonym %(test_schema)s_ptable for %(test_schema)s.parent;
 
-create synonym test_schema.local_table for local_table;
+create synonym %(test_schema)s.local_table for local_table;
 
 -- can't make a ref from local schema to the
 -- remote schema's table without this,
 -- *and* cant give yourself a grant !
 -- so we give it to public.  ideas welcome.
-grant references on test_schema.parent to public;
-grant references on test_schema.child to public;
-""".split(";"):
+grant references on %(test_schema)s.parent to public;
+grant references on %(test_schema)s.child to public;
+""" % {"test_schema": testing.config.test_schema}).split(";"):
             if stmt.strip():
                 testing.db.execute(stmt)
 
     @classmethod
     def teardown_class(cls):
-        for stmt in """
-drop table test_schema.child;
-drop table test_schema.parent;
+        for stmt in ("""
+drop table %(test_schema)s.child;
+drop table %(test_schema)s.parent;
 drop table local_table;
-drop synonym test_schema.ctable;
-drop synonym test_schema.ptable;
-drop synonym test_schema_ptable;
-drop synonym test_schema.local_table;
+drop synonym %(test_schema)s.ctable;
+drop synonym %(test_schema)s.ptable;
+drop synonym %(test_schema)s_ptable;
+drop synonym %(test_schema)s.local_table;
 
-""".split(";"):
+""" % {"test_schema": testing.config.test_schema}).split(";"):
             if stmt.strip():
                 testing.db.execute(stmt)
 
@@ -798,11 +798,16 @@ drop synonym test_schema.local_table;
 
     def test_reflect_alt_synonym_owner_local_table(self):
         meta = MetaData(testing.db)
-        parent = Table('local_table', meta, autoload=True,
-                            oracle_resolve_synonyms=True, schema="test_schema")
-        self.assert_compile(parent.select(),
-                "SELECT test_schema.local_table.id, "
-                "test_schema.local_table.data FROM test_schema.local_table")
+        parent = Table(
+            'local_table', meta, autoload=True,
+            oracle_resolve_synonyms=True, schema=testing.config.test_schema)
+        self.assert_compile(
+            parent.select(),
+            "SELECT %(test_schema)s.local_table.id, "
+            "%(test_schema)s.local_table.data "
+            "FROM %(test_schema)s.local_table" %
+            {"test_schema": testing.config.test_schema}
+        )
         select([parent]).execute().fetchall()
 
     @testing.provide_metadata
@@ -820,31 +825,41 @@ drop synonym test_schema.local_table;
         child.insert().execute({'cid': 1, 'pid': 1})
         eq_(child.select().execute().fetchall(), [(1, 1)])
 
-
     def test_reflect_alt_owner_explicit(self):
         meta = MetaData(testing.db)
-        parent = Table('parent', meta, autoload=True, schema='test_schema')
-        child = Table('child', meta, autoload=True, schema='test_schema')
+        parent = Table(
+            'parent', meta, autoload=True,
+            schema=testing.config.test_schema)
+        child = Table(
+            'child', meta, autoload=True,
+            schema=testing.config.test_schema)
 
-        self.assert_compile(parent.join(child),
-                "test_schema.parent JOIN test_schema.child ON "
-                "test_schema.parent.id = test_schema.child.parent_id")
+        self.assert_compile(
+            parent.join(child),
+            "%(test_schema)s.parent JOIN %(test_schema)s.child ON "
+            "%(test_schema)s.parent.id = %(test_schema)s.child.parent_id" % {
+                "test_schema": testing.config.test_schema
+            })
         select([parent, child]).\
-                select_from(parent.join(child)).\
-                execute().fetchall()
+            select_from(parent.join(child)).\
+            execute().fetchall()
 
     def test_reflect_local_to_remote(self):
-        testing.db.execute('CREATE TABLE localtable (id INTEGER '
-                           'PRIMARY KEY, parent_id INTEGER REFERENCES '
-                           'test_schema.parent(id))')
+        testing.db.execute(
+            'CREATE TABLE localtable (id INTEGER '
+            'PRIMARY KEY, parent_id INTEGER REFERENCES '
+            '%(test_schema)s.parent(id))' % {
+                "test_schema": testing.config.test_schema})
         try:
             meta = MetaData(testing.db)
             lcl = Table('localtable', meta, autoload=True)
-            parent = meta.tables['test_schema.parent']
+            parent = meta.tables['%s.parent' % testing.config.test_schema]
             self.assert_compile(parent.join(lcl),
-                                'test_schema.parent JOIN localtable ON '
-                                'test_schema.parent.id = '
-                                'localtable.parent_id')
+                                '%(test_schema)s.parent JOIN localtable ON '
+                                '%(test_schema)s.parent.id = '
+                                'localtable.parent_id' % {
+                                    "test_schema": testing.config.test_schema}
+                                )
             select([parent,
                    lcl]).select_from(parent.join(lcl)).execute().fetchall()
         finally:
@@ -852,30 +867,36 @@ drop synonym test_schema.local_table;
 
     def test_reflect_alt_owner_implicit(self):
         meta = MetaData(testing.db)
-        parent = Table('parent', meta, autoload=True,
-                       schema='test_schema')
-        child = Table('child', meta, autoload=True, schema='test_schema'
-                      )
-        self.assert_compile(parent.join(child),
-                            'test_schema.parent JOIN test_schema.child '
-                            'ON test_schema.parent.id = '
-                            'test_schema.child.parent_id')
+        parent = Table(
+            'parent', meta, autoload=True,
+            schema=testing.config.test_schema)
+        child = Table(
+            'child', meta, autoload=True,
+            schema=testing.config.test_schema)
+        self.assert_compile(
+            parent.join(child),
+            '%(test_schema)s.parent JOIN %(test_schema)s.child '
+            'ON %(test_schema)s.parent.id = '
+            '%(test_schema)s.child.parent_id' % {
+                "test_schema": testing.config.test_schema})
         select([parent,
                child]).select_from(parent.join(child)).execute().fetchall()
 
     def test_reflect_alt_owner_synonyms(self):
         testing.db.execute('CREATE TABLE localtable (id INTEGER '
                            'PRIMARY KEY, parent_id INTEGER REFERENCES '
-                           'test_schema.ptable(id))')
+                           '%s.ptable(id))' % testing.config.test_schema)
         try:
             meta = MetaData(testing.db)
             lcl = Table('localtable', meta, autoload=True,
                         oracle_resolve_synonyms=True)
-            parent = meta.tables['test_schema.ptable']
-            self.assert_compile(parent.join(lcl),
-                                'test_schema.ptable JOIN localtable ON '
-                                'test_schema.ptable.id = '
-                                'localtable.parent_id')
+            parent = meta.tables['%s.ptable' % testing.config.test_schema]
+            self.assert_compile(
+                parent.join(lcl),
+                '%(test_schema)s.ptable JOIN localtable ON '
+                '%(test_schema)s.ptable.id = '
+                'localtable.parent_id' % {
+                    "test_schema": testing.config.test_schema})
             select([parent,
                    lcl]).select_from(parent.join(lcl)).execute().fetchall()
         finally:
@@ -884,18 +905,22 @@ drop synonym test_schema.local_table;
     def test_reflect_remote_synonyms(self):
         meta = MetaData(testing.db)
         parent = Table('ptable', meta, autoload=True,
-                       schema='test_schema',
+                       schema=testing.config.test_schema,
                        oracle_resolve_synonyms=True)
         child = Table('ctable', meta, autoload=True,
-                      schema='test_schema',
+                      schema=testing.config.test_schema,
                       oracle_resolve_synonyms=True)
-        self.assert_compile(parent.join(child),
-                            'test_schema.ptable JOIN '
-                            'test_schema.ctable ON test_schema.ptable.i'
-                            'd = test_schema.ctable.parent_id')
+        self.assert_compile(
+            parent.join(child),
+            '%(test_schema)s.ptable JOIN '
+            '%(test_schema)s.ctable '
+            'ON %(test_schema)s.ptable.id = '
+            '%(test_schema)s.ctable.parent_id' % {
+                "test_schema": testing.config.test_schema})
         select([parent,
                child]).select_from(parent.join(child)).execute().fetchall()
 
+
 class ConstraintTest(fixtures.TablesTest):
 
     __only_on__ = 'oracle'
test/engine/test_reflection.py
index 1db37851d40614c8a66257276f430a515b09822a..1ddae6b40a04decb87fd468780b27b05e97ff5f7 100644
@@ -1237,8 +1237,10 @@ class SchemaTest(fixtures.TestBase):
     @testing.requires.schemas
     @testing.requires.cross_schema_fk_reflection
     def test_has_schema(self):
-        eq_(testing.db.dialect.has_schema(testing.db, 'test_schema'), True)
-        eq_(testing.db.dialect.has_schema(testing.db, 'sa_fake_schema_123'), False)
+        eq_(testing.db.dialect.has_schema(testing.db,
+            testing.config.test_schema), True)
+        eq_(testing.db.dialect.has_schema(testing.db,
+            'sa_fake_schema_123'), False)
 
     @testing.requires.schemas
     @testing.fails_on('sqlite', 'FIXME: unknown')
@@ -1320,14 +1322,17 @@ class SchemaTest(fixtures.TestBase):
     @testing.provide_metadata
     def test_metadata_reflect_schema(self):
         metadata = self.metadata
-        createTables(metadata, "test_schema")
+        createTables(metadata, testing.config.test_schema)
         metadata.create_all()
-        m2 = MetaData(schema="test_schema", bind=testing.db)
+        m2 = MetaData(schema=testing.config.test_schema, bind=testing.db)
         m2.reflect()
         eq_(
             set(m2.tables),
-            set(['test_schema.dingalings', 'test_schema.users',
-                'test_schema.email_addresses'])
+            set([
+                '%s.dingalings' % testing.config.test_schema,
+                '%s.users' % testing.config.test_schema,
+                '%s.email_addresses' % testing.config.test_schema
+                ])
         )
 
     @testing.requires.schemas
@@ -1339,16 +1344,16 @@ class SchemaTest(fixtures.TestBase):
 
         t2 = Table('t', self.metadata,
             Column('id1', sa.ForeignKey('t.id')),
-            schema="test_schema"
+            schema=testing.config.test_schema
         )
 
         self.metadata.create_all()
         m2 = MetaData()
-        m2.reflect(testing.db, schema="test_schema")
+        m2.reflect(testing.db, schema=testing.config.test_schema)
 
         m3 = MetaData()
         m3.reflect(testing.db)
-        m3.reflect(testing.db, schema="test_schema")
+        m3.reflect(testing.db, schema=testing.config.test_schema)
 
         eq_(
             set((t.name, t.schema) for t in m2.tables.values()),
test/sql/test_query.py
index 039e8d7e53568b4cbfc59501062a0012b6f59609..23f6da029488429f3dec16f9b1fe5e1d3a34858d 100644
@@ -1625,7 +1625,7 @@ class KeyTargetingTest(fixtures.TablesTest):
                 'wschema', metadata,
                 Column("a", CHAR(2), key="b"),
                 Column("c", CHAR(2), key="q"),
-                schema="test_schema"
+                schema=testing.config.test_schema
             )
 
     @classmethod
@@ -1637,12 +1637,12 @@ class KeyTargetingTest(fixtures.TablesTest):
         cls.tables.content.insert().execute(type="t1")
 
         if testing.requires.schemas.enabled:
-            cls.tables['test_schema.wschema'].insert().execute(
+            cls.tables['%s.wschema' % testing.config.test_schema].insert().execute(
                 dict(b="a1", q="c1"))
 
     @testing.requires.schemas
     def test_keyed_accessor_wschema(self):
-        keyed1 = self.tables['test_schema.wschema']
+        keyed1 = self.tables['%s.wschema' % testing.config.test_schema]
         row = testing.db.execute(keyed1.select()).first()
 
         eq_(row.b, "a1")
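
All of the test changes follow the same recipe: never spell "test_schema" literally, always read testing.config.test_schema so a follower process can substitute its own per-ident schema (e.g. "test_1_test_schema" on MySQL). A sketch of how a new test would follow the same convention (assumes a configured test environment; the table name is made up):

    from sqlalchemy import Table, Column, Integer
    from sqlalchemy.testing import config

    def define_tables(metadata):
        # Schema-qualified via the config, not a string literal.
        return Table(
            "example", metadata,
            Column("id", Integer, primary_key=True),
            schema=config.test_schema,
        )

    # Later lookups must build the key the same way:
    #     metadata.tables["%s.example" % config.test_schema]
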