From ef89f1d734f8ef64640aa7a2b0a4a1cfc5fcb49d Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Tue, 20 Oct 2009 17:32:44 +0000 Subject: [PATCH] Gaetan's "scopefunc" clarification patch --- lib/sqlalchemy/orm/scoping.py | 8 ++++++-- lib/sqlalchemy/util.py | 11 ++--------- 2 files changed, 8 insertions(+), 11 deletions(-) diff --git a/lib/sqlalchemy/orm/scoping.py b/lib/sqlalchemy/orm/scoping.py index 4339b68ebc..fa4d3abd2f 100644 --- a/lib/sqlalchemy/orm/scoping.py +++ b/lib/sqlalchemy/orm/scoping.py @@ -5,7 +5,8 @@ # the MIT License: http://www.opensource.org/licenses/mit-license.php import sqlalchemy.exceptions as sa_exc -from sqlalchemy.util import ScopedRegistry, to_list, get_cls_kwargs, deprecated +from sqlalchemy.util import ScopedRegistry, ThreadLocalRegistry, \ + to_list, get_cls_kwargs, deprecated from sqlalchemy.orm import ( EXT_CONTINUE, MapperExtension, class_mapper, object_session ) @@ -29,7 +30,10 @@ class ScopedSession(object): def __init__(self, session_factory, scopefunc=None): self.session_factory = session_factory - self.registry = ScopedRegistry(session_factory, scopefunc) + if scopefunc: + self.registry = ScopedRegistry(session_factory, scopefunc) + else: + self.registry = ThreadLocalRegistry(session_factory) self.extension = _ScopedExt(self) def __call__(self, **kwargs): diff --git a/lib/sqlalchemy/util.py b/lib/sqlalchemy/util.py index 8eeeda4555..308fe19a82 100644 --- a/lib/sqlalchemy/util.py +++ b/lib/sqlalchemy/util.py @@ -1124,14 +1124,7 @@ class ScopedRegistry(object): scopefunc a callable that will return a key to store/retrieve an object. - If None, ScopedRegistry uses a threading.local object instead. - """ - def __new__(cls, createfunc, scopefunc=None): - if not scopefunc: - return object.__new__(_TLocalRegistry) - else: - return object.__new__(cls) def __init__(self, createfunc, scopefunc): self.createfunc = createfunc @@ -1157,8 +1150,8 @@ class ScopedRegistry(object): except KeyError: pass -class _TLocalRegistry(ScopedRegistry): - def __init__(self, createfunc, scopefunc=None): +class ThreadLocalRegistry(ScopedRegistry): + def __init__(self, createfunc): self.createfunc = createfunc self.registry = threading.local() -- 2.47.3