]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- create_engine() reworked to be strict about incoming **kwargs. all keyword
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 12 Nov 2006 20:50:51 +0000 (20:50 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 12 Nov 2006 20:50:51 +0000 (20:50 +0000)
arguments must be consumed by one of the dialect, connection pool, and engine
constructors, else a TypeError is thrown which describes the full set of
invalid kwargs in relation to the selected dialect/pool/engine configuration.

CHANGES
lib/sqlalchemy/engine/base.py
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/engine/strategies.py
lib/sqlalchemy/util.py
test/engine/parseconnect.py
test/testbase.py

diff --git a/CHANGES b/CHANGES
index 1949a7f3706eb039f976f6c5d7943a71ad3869f5..82c60160390a5f01121533a77347d06a0f643869 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -30,6 +30,10 @@ the target class
 for [ticket:362]
 - "delete-orphan" for a certain type can be set on more than one parent class;
 the instance is an "orphan" only if its not attached to *any* of those parents
+- create_engine() reworked to be strict about incoming **kwargs.  all keyword
+arguments must be consumed by one of the dialect, connection pool, and engine
+constructors, else a TypeError is thrown which describes the full set of
+invalid kwargs in relation to the selected dialect/pool/engine configuration.
 
 0.3.0
 - General:
index 07a88659be34da70f1ae1d945874a70f378a264e..205e5aa02c78ef9e732aa754d12d517b167b4c81 100644 (file)
@@ -396,7 +396,7 @@ class Engine(sql.Executor, Connectable):
     Connects a ConnectionProvider, a Dialect and a CompilerFactory together to 
     provide a default implementation of SchemaEngine.
     """
-    def __init__(self, connection_provider, dialect, echo=None, **kwargs):
+    def __init__(self, connection_provider, dialect, echo=None):
         self.connection_provider = connection_provider
         self.dialect=dialect
         self.echo = echo
index 02d3e4608aa2bcec21b0c837a908c86f38ec90b0..4af539e784021a44352d0b87bfc10d4b5d246e7d 100644 (file)
@@ -6,7 +6,6 @@
 
 
 from sqlalchemy import schema, exceptions, util, sql, types
-from sqlalchemy import  pool as poollib
 import StringIO, sys, re
 from sqlalchemy.engine import base
 
@@ -14,30 +13,8 @@ from sqlalchemy.engine import base
 
 
 class PoolConnectionProvider(base.ConnectionProvider):
-    def __init__(self, dialect, url, poolclass=None, pool=None, **kwargs):
-        (cargs, cparams) = dialect.create_connect_args(url)
-        cparams.update(kwargs.pop('connect_args', {}))
-        
-        if pool is None:
-            kwargs.setdefault('echo', False)
-            kwargs.setdefault('use_threadlocal',True)
-            if poolclass is None:
-                poolclass = poollib.QueuePool
-            dbapi = dialect.dbapi()
-            if dbapi is None:
-                raise exceptions.InvalidRequestError("Cant get DBAPI module for dialect '%s'" % dialect)
-            def connect():
-                try:
-                    return dbapi.connect(*cargs, **cparams)
-                except Exception, e:
-                    raise exceptions.DBAPIError("Connection failed", e)
-            creator = kwargs.pop('creator', connect)
-            self._pool = poolclass(creator, **kwargs)
-        else:
-            if isinstance(pool, poollib.DBProxy):
-                self._pool = pool.get_pool(*cargs, **cparams)
-            else:
-                self._pool = pool
+    def __init__(self, pool):
+        self._pool = pool
     def get_connection(self):
         return self._pool.connect()
     def dispose(self):
index fe30aeb8d746465c1aca273a254f0e42354c25e8..d48412160d62f67c65e6fabc07a21b26b6a94eea 100644 (file)
@@ -6,6 +6,8 @@ this can be accomplished via a mod; see the sqlalchemy/mods package for details.
 
 
 from sqlalchemy.engine import base, default, threadlocal, url
+from sqlalchemy import util, exceptions
+from sqlalchemy import pool as poollib
 
 strategies = {}
 
@@ -22,29 +24,74 @@ class EngineStrategy(object):
         raise NotImplementedError()
 
 class DefaultEngineStrategy(EngineStrategy):
-    def create(self, name_or_url, **kwargs):    
+    def create(self, name_or_url, **kwargs):
+        # create url.URL object
         u = url.make_url(name_or_url)
+        
+        # get module from sqlalchemy.databases
         module = u.get_module()
 
-        dialect = module.dialect(**kwargs)
+        dialect_args = {}
+        # consume dialect arguments from kwargs
+        for k in util.get_cls_kwargs(module.dialect):
+            if k in kwargs:
+                dialect_args[k] = kwargs.pop(k)
+                
+        # create dialect
+        dialect = module.dialect(**dialect_args)
 
-        poolargs = {}
-        for key in (('echo_pool', 'echo'), ('pool_size', 'pool_size'), ('max_overflow', 'max_overflow'), ('poolclass', 'poolclass'), ('pool_timeout','timeout'), ('pool', 'pool'), ('pool_recycle','recycle'),('connect_args', 'connect_args'), ('creator', 'creator')):
-           if kwargs.has_key(key[0]):
-               poolargs[key[1]] = kwargs[key[0]]
-        poolclass = getattr(module, 'poolclass', None)
-        if poolclass is not None:
-           poolargs.setdefault('poolclass', poolclass)
-        poolargs['use_threadlocal'] = self.pool_threadlocal()
-        provider = self.get_pool_provider(dialect, u, **poolargs)
+        # assemble connection arguments
+        (cargs, cparams) = dialect.create_connect_args(u)
+        cparams.update(kwargs.pop('connect_args', {}))
 
-        return self.get_engine(provider, dialect, **kwargs)
+        # look for existing pool or create
+        pool = kwargs.pop('pool', None)
+        if pool is None:
+            dbapi = kwargs.pop('module', dialect.dbapi())
+            if dbapi is None:
+                raise exceptions.InvalidRequestError("Cant get DBAPI module for dialect '%s'" % dialect)
+            def connect():
+                try:
+                    return dbapi.connect(*cargs, **cparams)
+                except Exception, e:
+                    raise exceptions.DBAPIError("Connection failed", e)
+            creator = kwargs.pop('creator', connect)
+
+            poolclass = kwargs.pop('poolclass', getattr(module, 'poolclass', poollib.QueuePool))
+            pool_args = {}
+            # consume pool arguments from kwargs, translating a few of the arguments
+            for k in util.get_cls_kwargs(poolclass):
+                tk = {'echo':'echo_pool', 'timeout':'pool_timeout', 'recycle':'pool_recycle'}.get(k, k)
+                if tk in kwargs:
+                    pool_args[k] = kwargs.pop(tk)
+            pool_args['use_threadlocal'] = self.pool_threadlocal()
+            pool = poolclass(creator, **pool_args)
+        else:
+            if isinstance(pool, poollib.DBProxy):
+                pool = pool.get_pool(*cargs, **cparams)
+            else:
+                pool = pool
+
+        provider = self.get_pool_provider(pool)
+
+        # create engine.
+        engineclass = self.get_engine_cls()
+        engine_args = {}
+        for k in util.get_cls_kwargs(engineclass):
+            if k in kwargs:
+                engine_args[k] = kwargs.pop(k)
+                
+        # all kwargs should be consumed
+        if len(kwargs):
+            raise TypeError("Invalid argument(s) %s sent to create_engine(), using configuration %s/%s/%s.  Please check that the keyword arguments are appropriate for this combination of components." % (','.join(["'%s'" % k for k in kwargs]), dialect.__class__.__name__, pool.__class__.__name__, engineclass.__name__))
+            
+        return engineclass(provider, dialect, **engine_args)
 
     def pool_threadlocal(self):
         raise NotImplementedError()
-    def get_pool_provider(self, dialect, url, **kwargs):
+    def get_pool_provider(self, pool):
         raise NotImplementedError()
-    def get_engine(self, provider, dialect, **kwargs):
+    def get_engine_cls(self):
         raise NotImplementedError()
            
 class PlainEngineStrategy(DefaultEngineStrategy):
@@ -52,10 +99,10 @@ class PlainEngineStrategy(DefaultEngineStrategy):
         DefaultEngineStrategy.__init__(self, 'plain')
     def pool_threadlocal(self):
         return False
-    def get_pool_provider(self, dialect, url, **poolargs):
-        return default.PoolConnectionProvider(dialect, url, **poolargs)
-    def get_engine(self, provider, dialect, **kwargs):
-        return base.Engine(provider, dialect, **kwargs)
+    def get_pool_provider(self, pool):
+        return default.PoolConnectionProvider(pool)
+    def get_engine_cls(self):
+        return base.Engine
 PlainEngineStrategy()
 
 class ThreadLocalEngineStrategy(DefaultEngineStrategy):
@@ -63,10 +110,10 @@ class ThreadLocalEngineStrategy(DefaultEngineStrategy):
         DefaultEngineStrategy.__init__(self, 'threadlocal')
     def pool_threadlocal(self):
         return True
-    def get_pool_provider(self, dialect, url, **poolargs):
-        return threadlocal.TLocalConnectionProvider(dialect, url, **poolargs)
-    def get_engine(self, provider, dialect, **kwargs):
-        return threadlocal.TLEngine(provider, dialect, **kwargs)
+    def get_pool_provider(self, pool):
+        return threadlocal.TLocalConnectionProvider(pool)
+    def get_engine_cls(self):
+        return threadlocal.TLEngine
 ThreadLocalEngineStrategy()
 
 
index 3636d552349d66b9cf43ee52810dc9bcfebd2821..bd40039d4e88b79772d7f1733f05d478d9441ef7 100644 (file)
@@ -57,7 +57,18 @@ class ArgSingleton(type):
             instance = type.__call__(self, *args)
             ArgSingleton.instances[hashkey] = instance
             return instance
-        
+
+def get_cls_kwargs(cls):
+    """return the full set of legal kwargs for the given cls"""
+    kw = []
+    for c in cls.__mro__:
+        cons = c.__init__
+        if hasattr(cons, 'func_code'):
+            for vn in cons.func_code.co_varnames:
+                if vn != 'self':
+                    kw.append(vn)
+    return kw
+                        
 class SimpleProperty(object):
     """a "default" property accessor."""
     def __init__(self, key):
index 9af594a98b6f8f5b1f5d856bc9d9f96744cbbbd0..251e0e64dc4b97d081230e7aa41d38a7bf51fd08 100644 (file)
@@ -67,6 +67,75 @@ class CreateEngineTest(PersistTest):
         e = create_engine('postgres://', pool_recycle=472, module=dbapi)
         assert e.connection_provider._pool._recycle == 472
         
+    def testbadargs(self):
+        # good arg, use MockDBAPI to prevent oracle import errors
+        e = create_engine('oracle://', use_ansi=True, module=MockDBAPI())
+
+        # bad arg
+        try:
+            e = create_engine('postgres://', use_ansi=True, module=MockDBAPI())
+            assert False
+        except TypeError:
+            assert True
+        
+        # bad arg
+        try:
+            e = create_engine('oracle://', lala=5, use_ansi=True, module=MockDBAPI())
+            assert False
+        except TypeError:
+            assert True
+            
+        try:
+            e = create_engine('postgres://', lala=5, module=MockDBAPI())
+            assert False
+        except TypeError:
+            assert True
+        
+        try:
+            e = create_engine('sqlite://', lala=5)
+            assert False
+        except TypeError:
+            assert True
+
+        try:
+            e = create_engine('mysql://', use_unicode=True)
+            assert False
+        except TypeError:
+            assert True
+
+        try:
+            # sqlite uses SingletonThreadPool which doesnt have max_overflow
+            e = create_engine('sqlite://', max_overflow=5)
+            assert False
+        except TypeError:
+            assert True
+            
+        e = create_engine('sqlite://', echo=True)
+        e = create_engine('mysql://', module=MockDBAPI(), connect_args={'use_unicode':True}, convert_unicode=True)
+        
+        e = create_engine('sqlite://', connect_args={'use_unicode':True}, convert_unicode=True)
+        try:
+            c = e.connect()
+            assert False
+        except exceptions.DBAPIError:
+            assert True
+            
+    def testpoolargs(self):
+        """test that connection pool args make it thru"""
+        e = create_engine('postgres://', creator=None, pool_recycle=-1, echo_pool=None, auto_close_cursors=False, disallow_open_cursors=True, module=MockDBAPI())
+        assert e.connection_provider._pool.auto_close_cursors is False
+        assert e.connection_provider._pool.disallow_open_cursors is True
+
+        # these args work for QueuePool
+        e = create_engine('postgres://', max_overflow=8, pool_timeout=60, poolclass=pool.QueuePool, module=MockDBAPI())
+
+        try:
+            # but not SingletonThreadPool
+            e = create_engine('sqlite://', max_overflow=8, pool_timeout=60, poolclass=pool.SingletonThreadPool)
+            assert False
+        except TypeError:
+            assert True
+
 class MockDBAPI(object):
     def __init__(self, **kwargs):
         self.kwargs = kwargs
index 509d81aee145ba26f07145ed624b7be94e54723d..ed68efd4d29e5b744a891bb141812fec6a8b760e 100644 (file)
@@ -98,9 +98,9 @@ def parse_argv():
     if options.enginestrategy is not None:
         opts['strategy'] = options.enginestrategy    
     if options.mockpool:
-        db = engine.create_engine(db_uri, default_ordering=True, poolclass=pool.AssertionPool, **opts)
+        db = engine.create_engine(db_uri, poolclass=pool.AssertionPool, **opts)
     else:
-        db = engine.create_engine(db_uri, default_ordering=True, **opts)
+        db = engine.create_engine(db_uri, **opts)
     db = EngineAssert(db)
 
     import logging