]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- The :func:`.engine_from_config` function has been improved so that
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 7 Dec 2013 23:38:15 +0000 (18:38 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 7 Dec 2013 23:38:15 +0000 (18:38 -0500)
we will be able to parse dialect-specific arguments from string
configuration dictionaries.  Dialect classes can now provide their
own list of parameter types and string-conversion routines.
The feature is not yet used by the built-in dialects, however.
[ticket:2875]

doc/build/changelog/changelog_09.rst
lib/sqlalchemy/engine/__init__.py
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/engine/strategies.py
lib/sqlalchemy/engine/util.py
test/engine/test_parseconnect.py

index 14688fbfdd1de8e2b5c5f96686b8a3af68c343c2..4656fcf331e992dc582a51e756defc8e33bafd22 100644 (file)
 .. changelog::
     :version: 0.9.0b2
 
+    .. change::
+        :tags: feature, engine
+        :tickets: 2875
+
+        The :func:`.engine_from_config` function has been improved so that
+        we will be able to parse dialect-specific arguments from string
+        configuration dictionaries.  Dialect classes can now provide their
+        own list of parameter types and string-conversion routines.
+        The feature is not yet used by the built-in dialects, however.
+
     .. change::
         :tags: bug, sql
         :tickets: 2879
index 16d214140c447699e486dfc32cd0f0bd9d744296..128c4e8f6a3c171fa5304151eff86ae6d523531f 100644 (file)
@@ -348,10 +348,13 @@ def engine_from_config(configuration, prefix='sqlalchemy.', **kwargs):
     arguments.
     """
 
-    opts = util._coerce_config(configuration, prefix)
-    opts.update(kwargs)
-    url = opts.pop('url')
-    return create_engine(url, **opts)
+    options = dict((key[len(prefix):], configuration[key])
+                   for key in configuration
+                   if key.startswith(prefix))
+    options['_coerce_config'] = True
+    options.update(kwargs)
+    url = options.pop('url')
+    return create_engine(url, **options)
 
 
 __all__ = (
index 8fb7c3bb8d24bc2ead86279f3e475e0ddfb1a64f..dec6f2cdfa78cf121f0dc2d6b03f377d904c52d0 100644 (file)
@@ -59,6 +59,18 @@ class DefaultDialect(interfaces.Dialect):
 
     supports_simple_order_by_label = True
 
+    engine_config_types = util.immutabledict([
+        ('convert_unicode', util.bool_or_str('force')),
+        ('pool_timeout', int),
+        ('echo', util.bool_or_str('debug')),
+        ('echo_pool', util.bool_or_str('debug')),
+        ('pool_recycle', int),
+        ('pool_size', int),
+        ('max_overflow', int),
+        ('pool_threadlocal', bool),
+        ('use_native_unicode', bool),
+    ])
+
     # if the NUMERIC type
     # returns decimal.Decimal.
     # *not* the FLOAT type however.
index 5a3b2c5af86dcf2e234f6e1c784ca5b3afbcfff0..4befe58fdac1876b24f09f9cc0e7c071d18f2db3 100644 (file)
@@ -49,18 +49,27 @@ class DefaultEngineStrategy(EngineStrategy):
 
         dialect_cls = u.get_dialect()
 
+        if kwargs.pop('_coerce_config', False):
+            def pop_kwarg(key, default=None):
+                value = kwargs.pop(key, default)
+                if key in dialect_cls.engine_config_types:
+                    value = dialect_cls.engine_config_types[key](value)
+                return value
+        else:
+            pop_kwarg = kwargs.pop
+
         dialect_args = {}
         # consume dialect arguments from kwargs
         for k in util.get_cls_kwargs(dialect_cls):
             if k in kwargs:
-                dialect_args[k] = kwargs.pop(k)
+                dialect_args[k] = pop_kwarg(k)
 
         dbapi = kwargs.pop('module', None)
         if dbapi is None:
             dbapi_args = {}
             for k in util.get_func_kwargs(dialect_cls.dbapi):
                 if k in kwargs:
-                    dbapi_args[k] = kwargs.pop(k)
+                    dbapi_args[k] = pop_kwarg(k)
             dbapi = dialect_cls.dbapi(**dbapi_args)
 
         dialect_args['dbapi'] = dbapi
@@ -70,10 +79,10 @@ class DefaultEngineStrategy(EngineStrategy):
 
         # assemble connection arguments
         (cargs, cparams) = dialect.create_connect_args(u)
-        cparams.update(kwargs.pop('connect_args', {}))
+        cparams.update(pop_kwarg('connect_args', {}))
 
         # look for existing pool or create
-        pool = kwargs.pop('pool', None)
+        pool = pop_kwarg('pool', None)
         if pool is None:
             def connect():
                 try:
@@ -87,9 +96,9 @@ class DefaultEngineStrategy(EngineStrategy):
                         )
                     )
 
-            creator = kwargs.pop('creator', connect)
+            creator = pop_kwarg('creator', connect)
 
-            poolclass = kwargs.pop('poolclass', None)
+            poolclass = pop_kwarg('poolclass', None)
             if poolclass is None:
                 poolclass = dialect_cls.get_pool_class(u)
             pool_args = {}
@@ -106,7 +115,7 @@ class DefaultEngineStrategy(EngineStrategy):
             for k in util.get_cls_kwargs(poolclass):
                 tk = translate.get(k, k)
                 if tk in kwargs:
-                    pool_args[k] = kwargs.pop(tk)
+                    pool_args[k] = pop_kwarg(tk)
             pool = poolclass(creator, **pool_args)
         else:
             if isinstance(pool, poollib._DBProxy):
@@ -119,7 +128,7 @@ class DefaultEngineStrategy(EngineStrategy):
         engine_args = {}
         for k in util.get_cls_kwargs(engineclass):
             if k in kwargs:
-                engine_args[k] = kwargs.pop(k)
+                engine_args[k] = pop_kwarg(k)
 
         _initialize = kwargs.pop('_initialize', True)
 
index e5645275181ab63590c1329b999b1604be0f333a..f4a2b0cc7c04aad1b6c3aeb602c984cc3b641b83 100644 (file)
@@ -6,28 +6,6 @@
 
 from .. import util
 
-
-def _coerce_config(configuration, prefix):
-    """Convert configuration values to expected types."""
-
-    options = dict((key[len(prefix):], configuration[key])
-                   for key in configuration
-                   if key.startswith(prefix))
-    for option, type_ in (
-        ('convert_unicode', util.bool_or_str('force')),
-        ('pool_timeout', int),
-        ('echo', util.bool_or_str('debug')),
-        ('echo_pool', util.bool_or_str('debug')),
-        ('pool_recycle', int),
-        ('pool_size', int),
-        ('max_overflow', int),
-        ('pool_threadlocal', bool),
-        ('use_native_unicode', bool),
-    ):
-        util.coerce_kw_type(options, option, type_)
-    return options
-
-
 def connection_memoize(key):
     """Decorator, memoize a function in a connection.info stash.
 
index c4d8b8edc60a8e932b4180687603161b31becb13..391b921445ecdf924b9e85312b2cc01370e982bd 100644 (file)
@@ -2,12 +2,11 @@ from sqlalchemy.testing import assert_raises, eq_, assert_raises_message
 from sqlalchemy.util.compat import configparser, StringIO
 import sqlalchemy.engine.url as url
 from sqlalchemy import create_engine, engine_from_config, exc, pool
-from sqlalchemy.engine.util import _coerce_config
 from sqlalchemy.engine.default import DefaultDialect
 import sqlalchemy as tsa
 from sqlalchemy.testing import fixtures
 from sqlalchemy import testing
-from sqlalchemy.testing.mock import Mock
+from sqlalchemy.testing.mock import Mock, MagicMock, patch
 
 
 class ParseConnectTest(fixtures.TestBase):
@@ -110,50 +109,6 @@ class CreateEngineTest(fixtures.TestBase):
                           module=dbapi, _initialize=False)
         c = e.connect()
 
-    def test_coerce_config(self):
-        raw = r"""
-[prefixed]
-sqlalchemy.url=postgresql://scott:tiger@somehost/test?fooz=somevalue
-sqlalchemy.convert_unicode=0
-sqlalchemy.echo=false
-sqlalchemy.echo_pool=1
-sqlalchemy.max_overflow=2
-sqlalchemy.pool_recycle=50
-sqlalchemy.pool_size=2
-sqlalchemy.pool_threadlocal=1
-sqlalchemy.pool_timeout=10
-[plain]
-url=postgresql://scott:tiger@somehost/test?fooz=somevalue
-convert_unicode=0
-echo=0
-echo_pool=1
-max_overflow=2
-pool_recycle=50
-pool_size=2
-pool_threadlocal=1
-pool_timeout=10
-"""
-        ini = configparser.ConfigParser()
-        ini.readfp(StringIO(raw))
-
-        expected = {
-            'url': 'postgresql://scott:tiger@somehost/test?fooz=somevalue',
-            'convert_unicode': 0,
-            'echo': False,
-            'echo_pool': True,
-            'max_overflow': 2,
-            'pool_recycle': 50,
-            'pool_size': 2,
-            'pool_threadlocal': True,
-            'pool_timeout': 10,
-            }
-
-        prefixed = dict(ini.items('prefixed'))
-        self.assert_(_coerce_config(prefixed, 'sqlalchemy.')
-                     == expected)
-
-        plain = dict(ini.items('plain'))
-        self.assert_(_coerce_config(plain, '') == expected)
 
     def test_engine_from_config(self):
         dbapi = mock_dbapi
@@ -170,19 +125,35 @@ pool_timeout=10
                             'z=somevalue')
         assert e.echo is True
 
-        for param, values in [
-            ('convert_unicode', ('true', 'false', 'force')),
-            ('echo', ('true', 'false', 'debug')),
-            ('echo_pool', ('true', 'false', 'debug')),
-            ('use_native_unicode', ('true', 'false')),
-        ]:
-            for value in values:
-                config = {
-                        'sqlalchemy.url': 'postgresql://scott:tiger@somehost/test',
-                        'sqlalchemy.%s' % param : value
-                }
-                cfg = _coerce_config(config, 'sqlalchemy.')
-                assert cfg[param] == {'true':True, 'false':False}.get(value, value)
+
+    def test_engine_from_config_custom(self):
+        from sqlalchemy import util
+        from sqlalchemy.dialects import registry
+        tokens = __name__.split(".")
+
+        class MyDialect(MockDialect):
+            engine_config_types = {
+                "foobar": int,
+                "bathoho": util.bool_or_str('force')
+            }
+
+            def __init__(self, foobar=None, bathoho=None, **kw):
+                self.foobar = foobar
+                self.bathoho = bathoho
+
+        global dialect
+        dialect = MyDialect
+        registry.register("mockdialect.barb",
+                    ".".join(tokens[0:-1]), tokens[-1])
+
+        config = {
+            "sqlalchemy.url": "mockdialect+barb://",
+            "sqlalchemy.foobar": "5",
+            "sqlalchemy.bathoho": "false"
+        }
+        e = engine_from_config(config, _initialize=False)
+        eq_(e.dialect.foobar, 5)
+        eq_(e.dialect.bathoho, False)
 
 
     def test_custom(self):
@@ -417,7 +388,7 @@ def MockDBAPI(**assert_kwargs):
             )
         return connection
 
-    return Mock(
+    return MagicMock(
                 sqlite_version_info=(99, 9, 9,),
                 version_info=(99, 9, 9,),
                 sqlite_version='99.9.9',