From: Mike Bayer Date: Tue, 20 Oct 2009 17:33:33 +0000 (+0000) Subject: merged scopefunc patch from r6420 of 0.5 branch X-Git-Tag: rel_0_6beta1~237 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=aceb90525aad6d3e807f5f2db548353a5fc33138;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git merged scopefunc patch from r6420 of 0.5 branch --- diff --git a/lib/sqlalchemy/orm/scoping.py b/lib/sqlalchemy/orm/scoping.py index f00b30849d..a8ed9c9108 100644 --- a/lib/sqlalchemy/orm/scoping.py +++ b/lib/sqlalchemy/orm/scoping.py @@ -5,7 +5,8 @@ # the MIT License: http://www.opensource.org/licenses/mit-license.php import sqlalchemy.exceptions as sa_exc -from sqlalchemy.util import ScopedRegistry, to_list, get_cls_kwargs, deprecated +from sqlalchemy.util import ScopedRegistry, ThreadLocalRegistry, \ + to_list, get_cls_kwargs, deprecated from sqlalchemy.orm import ( EXT_CONTINUE, MapperExtension, class_mapper, object_session ) @@ -29,7 +30,10 @@ class ScopedSession(object): def __init__(self, session_factory, scopefunc=None): self.session_factory = session_factory - self.registry = ScopedRegistry(session_factory, scopefunc) + if scopefunc: + self.registry = ScopedRegistry(session_factory, scopefunc) + else: + self.registry = ThreadLocalRegistry(session_factory) self.extension = _ScopedExt(self) def __call__(self, **kwargs): diff --git a/lib/sqlalchemy/util.py b/lib/sqlalchemy/util.py index 8f0b5583dd..da426cbd80 100644 --- a/lib/sqlalchemy/util.py +++ b/lib/sqlalchemy/util.py @@ -1180,14 +1180,7 @@ class ScopedRegistry(object): scopefunc a callable that will return a key to store/retrieve an object. - If None, ScopedRegistry uses a threading.local object instead. - """ - def __new__(cls, createfunc, scopefunc=None): - if not scopefunc: - return object.__new__(_TLocalRegistry) - else: - return object.__new__(cls) def __init__(self, createfunc, scopefunc): self.createfunc = createfunc @@ -1213,8 +1206,8 @@ class ScopedRegistry(object): except KeyError: pass -class _TLocalRegistry(ScopedRegistry): - def __init__(self, createfunc, scopefunc=None): +class ThreadLocalRegistry(ScopedRegistry): + def __init__(self, createfunc): self.createfunc = createfunc self.registry = threading.local()