]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Gaetan's "scopefunc" clarification patch
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 20 Oct 2009 17:32:44 +0000 (17:32 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 20 Oct 2009 17:32:44 +0000 (17:32 +0000)
lib/sqlalchemy/orm/scoping.py
lib/sqlalchemy/util.py

index 4339b68ebc5f4de4c12a88534deee5abd1c21293..fa4d3abd2f14f4b69b8f769aa38a37dbd7665307 100644 (file)
@@ -5,7 +5,8 @@
 # the MIT License: http://www.opensource.org/licenses/mit-license.php
 
 import sqlalchemy.exceptions as sa_exc
-from sqlalchemy.util import ScopedRegistry, to_list, get_cls_kwargs, deprecated
+from sqlalchemy.util import ScopedRegistry, ThreadLocalRegistry, \
+                            to_list, get_cls_kwargs, deprecated
 from sqlalchemy.orm import (
     EXT_CONTINUE, MapperExtension, class_mapper, object_session
     )
@@ -29,7 +30,10 @@ class ScopedSession(object):
 
     def __init__(self, session_factory, scopefunc=None):
         self.session_factory = session_factory
-        self.registry = ScopedRegistry(session_factory, scopefunc)
+        if scopefunc:
+            self.registry = ScopedRegistry(session_factory, scopefunc)
+        else:
+            self.registry = ThreadLocalRegistry(session_factory)
         self.extension = _ScopedExt(self)
 
     def __call__(self, **kwargs):
index 8eeeda45559ce2bd378dde2a7869929e3e013bdf..308fe19a821113350c97686d7e7f7263899286aa 100644 (file)
@@ -1124,14 +1124,7 @@ class ScopedRegistry(object):
 
     scopefunc
       a callable that will return a key to store/retrieve an object.
-      If None, ScopedRegistry uses a threading.local object instead.
-
     """
-    def __new__(cls, createfunc, scopefunc=None):
-        if not scopefunc:
-            return object.__new__(_TLocalRegistry)
-        else:
-            return object.__new__(cls)
 
     def __init__(self, createfunc, scopefunc):
         self.createfunc = createfunc
@@ -1157,8 +1150,8 @@ class ScopedRegistry(object):
         except KeyError:
             pass
 
-class _TLocalRegistry(ScopedRegistry):
-    def __init__(self, createfunc, scopefunc=None):
+class ThreadLocalRegistry(ScopedRegistry):
+    def __init__(self, createfunc):
         self.createfunc = createfunc
         self.registry = threading.local()