From: Mike Bayer Date: Sun, 12 Nov 2006 20:50:51 +0000 (+0000) Subject: - create_engine() reworked to be strict about incoming **kwargs. all keyword X-Git-Tag: rel_0_3_1~5 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=ecee1fb16cf6b7b5d01187191ea23260b8bcef2a;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - create_engine() reworked to be strict about incoming **kwargs. all keyword arguments must be consumed by one of the dialect, connection pool, and engine constructors, else a TypeError is thrown which describes the full set of invalid kwargs in relation to the selected dialect/pool/engine configuration. --- diff --git a/CHANGES b/CHANGES index 1949a7f370..82c6016039 100644 --- a/CHANGES +++ b/CHANGES @@ -30,6 +30,10 @@ the target class for [ticket:362] - "delete-orphan" for a certain type can be set on more than one parent class; the instance is an "orphan" only if its not attached to *any* of those parents +- create_engine() reworked to be strict about incoming **kwargs. all keyword +arguments must be consumed by one of the dialect, connection pool, and engine +constructors, else a TypeError is thrown which describes the full set of +invalid kwargs in relation to the selected dialect/pool/engine configuration. 0.3.0 - General: diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index 07a88659be..205e5aa02c 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -396,7 +396,7 @@ class Engine(sql.Executor, Connectable): Connects a ConnectionProvider, a Dialect and a CompilerFactory together to provide a default implementation of SchemaEngine. """ - def __init__(self, connection_provider, dialect, echo=None, **kwargs): + def __init__(self, connection_provider, dialect, echo=None): self.connection_provider = connection_provider self.dialect=dialect self.echo = echo diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index 02d3e4608a..4af539e784 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -6,7 +6,6 @@ from sqlalchemy import schema, exceptions, util, sql, types -from sqlalchemy import pool as poollib import StringIO, sys, re from sqlalchemy.engine import base @@ -14,30 +13,8 @@ from sqlalchemy.engine import base class PoolConnectionProvider(base.ConnectionProvider): - def __init__(self, dialect, url, poolclass=None, pool=None, **kwargs): - (cargs, cparams) = dialect.create_connect_args(url) - cparams.update(kwargs.pop('connect_args', {})) - - if pool is None: - kwargs.setdefault('echo', False) - kwargs.setdefault('use_threadlocal',True) - if poolclass is None: - poolclass = poollib.QueuePool - dbapi = dialect.dbapi() - if dbapi is None: - raise exceptions.InvalidRequestError("Cant get DBAPI module for dialect '%s'" % dialect) - def connect(): - try: - return dbapi.connect(*cargs, **cparams) - except Exception, e: - raise exceptions.DBAPIError("Connection failed", e) - creator = kwargs.pop('creator', connect) - self._pool = poolclass(creator, **kwargs) - else: - if isinstance(pool, poollib.DBProxy): - self._pool = pool.get_pool(*cargs, **cparams) - else: - self._pool = pool + def __init__(self, pool): + self._pool = pool def get_connection(self): return self._pool.connect() def dispose(self): diff --git a/lib/sqlalchemy/engine/strategies.py b/lib/sqlalchemy/engine/strategies.py index fe30aeb8d7..d48412160d 100644 --- a/lib/sqlalchemy/engine/strategies.py +++ b/lib/sqlalchemy/engine/strategies.py @@ -6,6 +6,8 @@ this can be accomplished via a mod; see the sqlalchemy/mods package for details. from sqlalchemy.engine import base, default, threadlocal, url +from sqlalchemy import util, exceptions +from sqlalchemy import pool as poollib strategies = {} @@ -22,29 +24,74 @@ class EngineStrategy(object): raise NotImplementedError() class DefaultEngineStrategy(EngineStrategy): - def create(self, name_or_url, **kwargs): + def create(self, name_or_url, **kwargs): + # create url.URL object u = url.make_url(name_or_url) + + # get module from sqlalchemy.databases module = u.get_module() - dialect = module.dialect(**kwargs) + dialect_args = {} + # consume dialect arguments from kwargs + for k in util.get_cls_kwargs(module.dialect): + if k in kwargs: + dialect_args[k] = kwargs.pop(k) + + # create dialect + dialect = module.dialect(**dialect_args) - poolargs = {} - for key in (('echo_pool', 'echo'), ('pool_size', 'pool_size'), ('max_overflow', 'max_overflow'), ('poolclass', 'poolclass'), ('pool_timeout','timeout'), ('pool', 'pool'), ('pool_recycle','recycle'),('connect_args', 'connect_args'), ('creator', 'creator')): - if kwargs.has_key(key[0]): - poolargs[key[1]] = kwargs[key[0]] - poolclass = getattr(module, 'poolclass', None) - if poolclass is not None: - poolargs.setdefault('poolclass', poolclass) - poolargs['use_threadlocal'] = self.pool_threadlocal() - provider = self.get_pool_provider(dialect, u, **poolargs) + # assemble connection arguments + (cargs, cparams) = dialect.create_connect_args(u) + cparams.update(kwargs.pop('connect_args', {})) - return self.get_engine(provider, dialect, **kwargs) + # look for existing pool or create + pool = kwargs.pop('pool', None) + if pool is None: + dbapi = kwargs.pop('module', dialect.dbapi()) + if dbapi is None: + raise exceptions.InvalidRequestError("Cant get DBAPI module for dialect '%s'" % dialect) + def connect(): + try: + return dbapi.connect(*cargs, **cparams) + except Exception, e: + raise exceptions.DBAPIError("Connection failed", e) + creator = kwargs.pop('creator', connect) + + poolclass = kwargs.pop('poolclass', getattr(module, 'poolclass', poollib.QueuePool)) + pool_args = {} + # consume pool arguments from kwargs, translating a few of the arguments + for k in util.get_cls_kwargs(poolclass): + tk = {'echo':'echo_pool', 'timeout':'pool_timeout', 'recycle':'pool_recycle'}.get(k, k) + if tk in kwargs: + pool_args[k] = kwargs.pop(tk) + pool_args['use_threadlocal'] = self.pool_threadlocal() + pool = poolclass(creator, **pool_args) + else: + if isinstance(pool, poollib.DBProxy): + pool = pool.get_pool(*cargs, **cparams) + else: + pool = pool + + provider = self.get_pool_provider(pool) + + # create engine. + engineclass = self.get_engine_cls() + engine_args = {} + for k in util.get_cls_kwargs(engineclass): + if k in kwargs: + engine_args[k] = kwargs.pop(k) + + # all kwargs should be consumed + if len(kwargs): + raise TypeError("Invalid argument(s) %s sent to create_engine(), using configuration %s/%s/%s. Please check that the keyword arguments are appropriate for this combination of components." % (','.join(["'%s'" % k for k in kwargs]), dialect.__class__.__name__, pool.__class__.__name__, engineclass.__name__)) + + return engineclass(provider, dialect, **engine_args) def pool_threadlocal(self): raise NotImplementedError() - def get_pool_provider(self, dialect, url, **kwargs): + def get_pool_provider(self, pool): raise NotImplementedError() - def get_engine(self, provider, dialect, **kwargs): + def get_engine_cls(self): raise NotImplementedError() class PlainEngineStrategy(DefaultEngineStrategy): @@ -52,10 +99,10 @@ class PlainEngineStrategy(DefaultEngineStrategy): DefaultEngineStrategy.__init__(self, 'plain') def pool_threadlocal(self): return False - def get_pool_provider(self, dialect, url, **poolargs): - return default.PoolConnectionProvider(dialect, url, **poolargs) - def get_engine(self, provider, dialect, **kwargs): - return base.Engine(provider, dialect, **kwargs) + def get_pool_provider(self, pool): + return default.PoolConnectionProvider(pool) + def get_engine_cls(self): + return base.Engine PlainEngineStrategy() class ThreadLocalEngineStrategy(DefaultEngineStrategy): @@ -63,10 +110,10 @@ class ThreadLocalEngineStrategy(DefaultEngineStrategy): DefaultEngineStrategy.__init__(self, 'threadlocal') def pool_threadlocal(self): return True - def get_pool_provider(self, dialect, url, **poolargs): - return threadlocal.TLocalConnectionProvider(dialect, url, **poolargs) - def get_engine(self, provider, dialect, **kwargs): - return threadlocal.TLEngine(provider, dialect, **kwargs) + def get_pool_provider(self, pool): + return threadlocal.TLocalConnectionProvider(pool) + def get_engine_cls(self): + return threadlocal.TLEngine ThreadLocalEngineStrategy() diff --git a/lib/sqlalchemy/util.py b/lib/sqlalchemy/util.py index 3636d55234..bd40039d4e 100644 --- a/lib/sqlalchemy/util.py +++ b/lib/sqlalchemy/util.py @@ -57,7 +57,18 @@ class ArgSingleton(type): instance = type.__call__(self, *args) ArgSingleton.instances[hashkey] = instance return instance - + +def get_cls_kwargs(cls): + """return the full set of legal kwargs for the given cls""" + kw = [] + for c in cls.__mro__: + cons = c.__init__ + if hasattr(cons, 'func_code'): + for vn in cons.func_code.co_varnames: + if vn != 'self': + kw.append(vn) + return kw + class SimpleProperty(object): """a "default" property accessor.""" def __init__(self, key): diff --git a/test/engine/parseconnect.py b/test/engine/parseconnect.py index 9af594a98b..251e0e64dc 100644 --- a/test/engine/parseconnect.py +++ b/test/engine/parseconnect.py @@ -67,6 +67,75 @@ class CreateEngineTest(PersistTest): e = create_engine('postgres://', pool_recycle=472, module=dbapi) assert e.connection_provider._pool._recycle == 472 + def testbadargs(self): + # good arg, use MockDBAPI to prevent oracle import errors + e = create_engine('oracle://', use_ansi=True, module=MockDBAPI()) + + # bad arg + try: + e = create_engine('postgres://', use_ansi=True, module=MockDBAPI()) + assert False + except TypeError: + assert True + + # bad arg + try: + e = create_engine('oracle://', lala=5, use_ansi=True, module=MockDBAPI()) + assert False + except TypeError: + assert True + + try: + e = create_engine('postgres://', lala=5, module=MockDBAPI()) + assert False + except TypeError: + assert True + + try: + e = create_engine('sqlite://', lala=5) + assert False + except TypeError: + assert True + + try: + e = create_engine('mysql://', use_unicode=True) + assert False + except TypeError: + assert True + + try: + # sqlite uses SingletonThreadPool which doesnt have max_overflow + e = create_engine('sqlite://', max_overflow=5) + assert False + except TypeError: + assert True + + e = create_engine('sqlite://', echo=True) + e = create_engine('mysql://', module=MockDBAPI(), connect_args={'use_unicode':True}, convert_unicode=True) + + e = create_engine('sqlite://', connect_args={'use_unicode':True}, convert_unicode=True) + try: + c = e.connect() + assert False + except exceptions.DBAPIError: + assert True + + def testpoolargs(self): + """test that connection pool args make it thru""" + e = create_engine('postgres://', creator=None, pool_recycle=-1, echo_pool=None, auto_close_cursors=False, disallow_open_cursors=True, module=MockDBAPI()) + assert e.connection_provider._pool.auto_close_cursors is False + assert e.connection_provider._pool.disallow_open_cursors is True + + # these args work for QueuePool + e = create_engine('postgres://', max_overflow=8, pool_timeout=60, poolclass=pool.QueuePool, module=MockDBAPI()) + + try: + # but not SingletonThreadPool + e = create_engine('sqlite://', max_overflow=8, pool_timeout=60, poolclass=pool.SingletonThreadPool) + assert False + except TypeError: + assert True + class MockDBAPI(object): def __init__(self, **kwargs): self.kwargs = kwargs diff --git a/test/testbase.py b/test/testbase.py index 509d81aee1..ed68efd4d2 100644 --- a/test/testbase.py +++ b/test/testbase.py @@ -98,9 +98,9 @@ def parse_argv(): if options.enginestrategy is not None: opts['strategy'] = options.enginestrategy if options.mockpool: - db = engine.create_engine(db_uri, default_ordering=True, poolclass=pool.AssertionPool, **opts) + db = engine.create_engine(db_uri, poolclass=pool.AssertionPool, **opts) else: - db = engine.create_engine(db_uri, default_ordering=True, **opts) + db = engine.create_engine(db_uri, **opts) db = EngineAssert(db) import logging