from sqlalchemy import schema, exceptions, util, sql, types
-from sqlalchemy import pool as poollib
import StringIO, sys, re
from sqlalchemy.engine import base
class PoolConnectionProvider(base.ConnectionProvider):
-    def __init__(self, dialect, url, poolclass=None, pool=None, **kwargs):
-        (cargs, cparams) = dialect.create_connect_args(url)
-        cparams.update(kwargs.pop('connect_args', {}))
-
-        if pool is None:
-            kwargs.setdefault('echo', False)
-            kwargs.setdefault('use_threadlocal',True)
-            if poolclass is None:
-                poolclass = poollib.QueuePool
-            dbapi = dialect.dbapi()
-            if dbapi is None:
-                raise exceptions.InvalidRequestError("Cant get DBAPI module for dialect '%s'" % dialect)
-            def connect():
-                try:
-                    return dbapi.connect(*cargs, **cparams)
-                except Exception, e:
-                    raise exceptions.DBAPIError("Connection failed", e)
-            creator = kwargs.pop('creator', connect)
-            self._pool = poolclass(creator, **kwargs)
-        else:
-            if isinstance(pool, poollib.DBProxy):
-                self._pool = pool.get_pool(*cargs, **cparams)
-            else:
-                self._pool = pool
+    def __init__(self, pool):
+        self._pool = pool
    def get_connection(self):
        return self._pool.connect()
    def dispose(self):
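
# Illustrative usage sketch (not part of the patch; assumes the stdlib sqlite3
# DBAPI and the QueuePool referenced above): the provider now simply wraps a
# pool that the caller has already constructed.
#
#     import sqlite3
#     from sqlalchemy import pool as poollib
#
#     p = poollib.QueuePool(lambda: sqlite3.connect(':memory:'), use_threadlocal=True)
#     provider = PoolConnectionProvider(p)
#     conn = provider.get_connection()    # checks a raw connection out of the pool
#
# All decisions about connect arguments, pool class, and pool keyword arguments
# now happen in the engine strategy before the provider is constructed.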
from sqlalchemy.engine import base, default, threadlocal, url
+from sqlalchemy import util, exceptions
+from sqlalchemy import pool as poollib
strategies = {}
        raise NotImplementedError()
class DefaultEngineStrategy(EngineStrategy):
-    def create(self, name_or_url, **kwargs):
+    def create(self, name_or_url, **kwargs):
+        # create url.URL object
        u = url.make_url(name_or_url)
+
+        # get module from sqlalchemy.databases
        module = u.get_module()
-        dialect = module.dialect(**kwargs)
+        dialect_args = {}
+        # consume dialect arguments from kwargs
+        for k in util.get_cls_kwargs(module.dialect):
+            if k in kwargs:
+                dialect_args[k] = kwargs.pop(k)
+
+        # create dialect
+        dialect = module.dialect(**dialect_args)
-        poolargs = {}
-        for key in (('echo_pool', 'echo'), ('pool_size', 'pool_size'), ('max_overflow', 'max_overflow'), ('poolclass', 'poolclass'), ('pool_timeout','timeout'), ('pool', 'pool'), ('pool_recycle','recycle'),('connect_args', 'connect_args'), ('creator', 'creator')):
-            if kwargs.has_key(key[0]):
-                poolargs[key[1]] = kwargs[key[0]]
-        poolclass = getattr(module, 'poolclass', None)
-        if poolclass is not None:
-            poolargs.setdefault('poolclass', poolclass)
-        poolargs['use_threadlocal'] = self.pool_threadlocal()
-        provider = self.get_pool_provider(dialect, u, **poolargs)
+        # assemble connection arguments
+        (cargs, cparams) = dialect.create_connect_args(u)
+        cparams.update(kwargs.pop('connect_args', {}))
-        return self.get_engine(provider, dialect, **kwargs)
+        # look for existing pool or create
+        pool = kwargs.pop('pool', None)
+        if pool is None:
+            dbapi = kwargs.pop('module', dialect.dbapi())
+            if dbapi is None:
+                raise exceptions.InvalidRequestError("Can't get DBAPI module for dialect '%s'" % dialect)
+            def connect():
+                try:
+                    return dbapi.connect(*cargs, **cparams)
+                except Exception, e:
+                    raise exceptions.DBAPIError("Connection failed", e)
+            creator = kwargs.pop('creator', connect)
+
+            poolclass = kwargs.pop('poolclass', getattr(module, 'poolclass', poollib.QueuePool))
+            pool_args = {}
+            # consume pool arguments from kwargs, translating a few of the arguments
+            for k in util.get_cls_kwargs(poolclass):
+                tk = {'echo':'echo_pool', 'timeout':'pool_timeout', 'recycle':'pool_recycle'}.get(k, k)
+                if tk in kwargs:
+                    pool_args[k] = kwargs.pop(tk)
+            pool_args['use_threadlocal'] = self.pool_threadlocal()
+            pool = poolclass(creator, **pool_args)
+        else:
+            if isinstance(pool, poollib.DBProxy):
+                pool = pool.get_pool(*cargs, **cparams)
+
+        provider = self.get_pool_provider(pool)
+
+        # create engine.
+        engineclass = self.get_engine_cls()
+        engine_args = {}
+        for k in util.get_cls_kwargs(engineclass):
+            if k in kwargs:
+                engine_args[k] = kwargs.pop(k)
+
+        # all kwargs should be consumed
+        if len(kwargs):
+            raise TypeError("Invalid argument(s) %s sent to create_engine(), using configuration %s/%s/%s. Please check that the keyword arguments are appropriate for this combination of components." % (','.join(["'%s'" % k for k in kwargs]), dialect.__class__.__name__, pool.__class__.__name__, engineclass.__name__))
+
+        return engineclass(provider, dialect, **engine_args)
    def pool_threadlocal(self):
        raise NotImplementedError()
-    def get_pool_provider(self, dialect, url, **kwargs):
+    def get_pool_provider(self, pool):
        raise NotImplementedError()
-    def get_engine(self, provider, dialect, **kwargs):
+    def get_engine_cls(self):
        raise NotImplementedError()
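
# Routing sketch (illustrative, not part of the patch; assumes the oracle dialect
# takes a 'use_ansi' constructor argument and QueuePool a 'pool_size' argument):
# each component consumes only the keyword arguments its constructor declares,
# e.g. for
#
#     create_engine('oracle://', use_ansi=False, pool_size=10,
#                   echo_pool=True, echo=True, module=some_dbapi)
#
# 'use_ansi' goes to the dialect, 'pool_size' and 'echo_pool' (translated to the
# pool's 'echo' argument) go to QueuePool, 'echo' goes to the Engine, and any
# leftover name raises the TypeError constructed in create() above.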
class PlainEngineStrategy(DefaultEngineStrategy):
        DefaultEngineStrategy.__init__(self, 'plain')
    def pool_threadlocal(self):
        return False
-    def get_pool_provider(self, dialect, url, **poolargs):
-        return default.PoolConnectionProvider(dialect, url, **poolargs)
-    def get_engine(self, provider, dialect, **kwargs):
-        return base.Engine(provider, dialect, **kwargs)
+    def get_pool_provider(self, pool):
+        return default.PoolConnectionProvider(pool)
+    def get_engine_cls(self):
+        return base.Engine
PlainEngineStrategy()
class ThreadLocalEngineStrategy(DefaultEngineStrategy):
        DefaultEngineStrategy.__init__(self, 'threadlocal')
    def pool_threadlocal(self):
        return True
-    def get_pool_provider(self, dialect, url, **poolargs):
-        return threadlocal.TLocalConnectionProvider(dialect, url, **poolargs)
-    def get_engine(self, provider, dialect, **kwargs):
-        return threadlocal.TLEngine(provider, dialect, **kwargs)
+    def get_pool_provider(self, pool):
+        return threadlocal.TLocalConnectionProvider(pool)
+    def get_engine_cls(self):
+        return threadlocal.TLEngine
ThreadLocalEngineStrategy()
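
# Selection sketch (illustrative, not part of the patch; assumes create_engine()
# resolves its 'strategy' keyword against the module-level 'strategies' dict,
# with 'plain' as the default):
#
#     plain_engine = create_engine('sqlite://')                         # PlainEngineStrategy
#     tl_engine = create_engine('sqlite://', strategy='threadlocal')    # ThreadLocalEngineStrategy
#
# The two strategies differ only in pool_threadlocal(), the connection provider
# class, and the Engine class returned by get_engine_cls().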
        e = create_engine('postgres://', pool_recycle=472, module=dbapi)
        assert e.connection_provider._pool._recycle == 472
+    def testbadargs(self):
+        # good arg, use MockDBAPI to prevent oracle import errors
+        e = create_engine('oracle://', use_ansi=True, module=MockDBAPI())
+
+        # bad arg
+        try:
+            e = create_engine('postgres://', use_ansi=True, module=MockDBAPI())
+            assert False
+        except TypeError:
+            assert True
+
+        # bad arg
+        try:
+            e = create_engine('oracle://', lala=5, use_ansi=True, module=MockDBAPI())
+            assert False
+        except TypeError:
+            assert True
+
+        # bad arg
+        try:
+            e = create_engine('postgres://', lala=5, module=MockDBAPI())
+            assert False
+        except TypeError:
+            assert True
+
+        # bad arg
+        try:
+            e = create_engine('sqlite://', lala=5)
+            assert False
+        except TypeError:
+            assert True
+
+        # bad arg: use_unicode is a DBAPI connect argument, not an engine argument
+        try:
+            e = create_engine('mysql://', use_unicode=True)
+            assert False
+        except TypeError:
+            assert True
+
+        try:
+            # sqlite uses SingletonThreadPool which doesn't have max_overflow
+            e = create_engine('sqlite://', max_overflow=5)
+            assert False
+        except TypeError:
+            assert True
+
+        e = create_engine('sqlite://', echo=True)
+        e = create_engine('mysql://', module=MockDBAPI(), connect_args={'use_unicode':True}, convert_unicode=True)
+
+        # connect_args go straight to the DBAPI's connect(); sqlite's connect()
+        # does not accept use_unicode, so the failure surfaces as a DBAPIError
+        e = create_engine('sqlite://', connect_args={'use_unicode':True}, convert_unicode=True)
+        try:
+            c = e.connect()
+            assert False
+        except exceptions.DBAPIError:
+            assert True
+
+    def testpoolargs(self):
+        """test that connection pool args make it through"""
+        e = create_engine('postgres://', creator=None, pool_recycle=-1, echo_pool=None, auto_close_cursors=False, disallow_open_cursors=True, module=MockDBAPI())
+        assert e.connection_provider._pool.auto_close_cursors is False
+        assert e.connection_provider._pool.disallow_open_cursors is True
+
+        # these args work for QueuePool
+        e = create_engine('postgres://', max_overflow=8, pool_timeout=60, poolclass=pool.QueuePool, module=MockDBAPI())
+
+        try:
+            # but not SingletonThreadPool
+            e = create_engine('sqlite://', max_overflow=8, pool_timeout=60, poolclass=pool.SingletonThreadPool)
+            assert False
+        except TypeError:
+            assert True
+
class MockDBAPI(object):
    def __init__(self, **kwargs):
        self.kwargs = kwargs