]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
objectstore refactored to have more flexible scopes for UnitOfWork
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 10 Feb 2006 23:46:42 +0000 (23:46 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 10 Feb 2006 23:46:42 +0000 (23:46 +0000)
central access point is now a Session object which maintains different
kinds of scopes for collections of one or more UnitOfWork objects
individual object instances get bound to a specific Session

lib/sqlalchemy/mapping/mapper.py
lib/sqlalchemy/mapping/objectstore.py
lib/sqlalchemy/util.py

index 8fa17035364f712e31f208ed0c55233b7a2dc830..9e110b67b527463285f3733572d36014856884ba 100644 (file)
@@ -205,6 +205,7 @@ class Mapper(object):
             oldinit = self.class_.__init__
             def init(self, *args, **kwargs):
                 nohist = kwargs.pop('_mapper_nohistory', False)
+                session = kwargs.pop('_sa_session', objectstore.session())
                 if oldinit is not None:
                     try:
                         oldinit(self, *args, **kwargs)
@@ -212,7 +213,7 @@ class Mapper(object):
                         # re-raise with the offending class name added to help in debugging
                         raise TypeError, '%s.%s' %(self.__class__.__name__, msg)
                 if not nohist:
-                    objectstore.uow().register_new(self)
+                    session.register_new(self)
             self.class_.__init__ = init
         mapper_registry[self.class_] = self
         self.class_.c = self.c
@@ -245,7 +246,7 @@ class Mapper(object):
                 
         # store new stuff in the identity map
         for value in imap.values():
-            objectstore.uow().register_clean(value)
+            objectstore.session().register_clean(value)
 
         if len(mappers):
             return [result] + otherresults
@@ -262,7 +263,7 @@ class Mapper(object):
         
     def _get(self, key, ident=None):
         try:
-            return objectstore.uow()._get(key)
+            return objectstore.session()._get(key)
         except KeyError:
             if ident is None:
                 ident = key[2]
@@ -676,8 +677,8 @@ class Mapper(object):
         # including modifying any of its related items lists, as its already
         # been exposed to being modified by the application.
         identitykey = self._identity_key(row)
-        if objectstore.uow().has_key(identitykey):
-            instance = objectstore.uow()._get(identitykey)
+        if objectstore.session().has_key(identitykey):
+            instance = objectstore.session()._get(identitykey)
 
             isnew = False
             if populate_existing:
index f7c7a206cd2ec166485615cdb93051f0265073d2..8a8b975ad0c8fcbb66dcfbf986615df1c1e8f28f 100644 (file)
@@ -25,115 +25,180 @@ __all__ = ['get_id_key', 'get_row_key', 'is_dirty', 'import_instance', 'commit',
 # printed to standard output.  also can be affected by creating an engine
 # with the "echo_uow=True" keyword argument.
 LOG = False
+    
+class Session(object):
+    """a scope-managed proxy to UnitOfWork instances.  Operations are delegated
+    to UnitOfWork objects which are accessed via a sqlalchemy.util.ScopedRegistry object.  
+    The registry is capable of maintaining object instances on a thread-local, 
+    per-application, or custom user-defined basis."""
+    
+    def __init__(self, uow=None, registry=None, hash_key=None):
+        """Initialize the objectstore with a UnitOfWork registry.  If called
+        with no arguments, creates a single UnitOfWork for all operations.
+        
+        registry - a sqlalchemy.util.ScopedRegistry to produce UnitOfWork instances.
+        This argument should not be used with the uow argument.
+        uow - a UnitOfWork to use for all operations.  this argument should not be
+        used with the registry argument.
+        hash_key - the hash_key used to identify objects against this session, which 
+        defaults to the id of the Session instance.
+        
+        """
+        if registry is None:
+            if uow is None:
+                uow = UnitOfWork(self)
+            self.registry = util.ScopedRegistry(lambda:uow, 'application')
+        else:
+            self.registry = registry
+        self._hash_key = hash_key
 
-def get_id_key(ident, class_, table):
-    """returns an identity-map key for use in storing/retrieving an item from the identity
-    map, given a tuple of the object's primary key values.
+    def get_id_key(ident, class_, table):
+        """returns an identity-map key for use in storing/retrieving an item from the identity
+        map, given a tuple of the object's primary key values.
+
+        ident - a tuple of primary key values corresponding to the object to be stored.  these
+        values should be in the same order as the primary keys of the table 
+
+        class_ - a reference to the object's class
+
+        table - a Table object where the object's primary fields are stored.
+
+        selectable - a Selectable object which represents all the object's column-based fields.
+        this Selectable may be synonymous with the table argument or can be a larger construct
+        containing that table. return value: a tuple object which is used as an identity key. """
+        return (class_, table.hash_key(), tuple(ident))
+    get_id_key = staticmethod(get_id_key)
+
+    def get_row_key(row, class_, table, primary_key):
+        """returns an identity-map key for use in storing/retrieving an item from the identity
+        map, given a result set row.
+
+        row - a sqlalchemy.dbengine.RowProxy instance or other map corresponding result-set
+        column names to their values within a row.
 
-    ident - a tuple of primary key values corresponding to the object to be stored.  these
-    values should be in the same order as the primary keys of the table 
+        class_ - a reference to the object's class
+
+        table - a Table object where the object's primary fields are stored.
+
+        selectable - a Selectable object which represents all the object's column-based fields.
+        this Selectable may be synonymous with the table argument or can be a larger construct
+        containing that table. return value: a tuple object which is used as an identity key.
+        """
+        return (class_, table.hash_key(), tuple([row[column] for column in primary_key]))
+    get_row_key = staticmethod(get_row_key)
+    
+    def _set_uow(self, uow):
+        self.registry.set(uow)
+    uow = property(lambda s:s.registry(), _set_uow, doc="Returns a scope-specific UnitOfWork object for this session.")
     
-    class_ - a reference to the object's class
+    hash_key = property(lambda s:s._hash_key or id(s))
 
-    table - a Table object where the object's primary fields are stored.
+    def bind_to(self, obj):
+        """given an object, binds it to this session.  changes on the object will affect
+        the currently scoped UnitOfWork maintained by this session."""
+        obj._sa_session_id = self.hash_key
 
-    selectable - a Selectable object which represents all the object's column-based fields.
-    this Selectable may be synonymous with the table argument or can be a larger construct
-    containing that table. return value: a tuple object which is used as an identity key. """
-    return (class_, table.hash_key(), tuple(ident))
-def get_row_key(row, class_, table, primary_key):
-    """returns an identity-map key for use in storing/retrieving an item from the identity
-    map, given a result set row.
+    def __getattr__(self, key):
+        """proxy other methods to our underlying UnitOfWork"""
+        return getattr(self.registry(), key)
 
-    row - a sqlalchemy.dbengine.RowProxy instance or other map corresponding result-set
-    column names to their values within a row.
+    def clear(self):
+        self.registry.clear()
 
-    class_ - a reference to the object's class
+    def delete(*obj):
+        """registers the given objects as to be deleted upon the next commit"""
+        u = registry()
+        for o in obj:
+            u.register_deleted(o)
+        
+    def import_instance(self, instance):
+        """places the given instance in the current thread's unit of work context,
+        either in the current IdentityMap or marked as "new".  Returns either the object
+        or the current corresponding version in the Identity Map.
+
+        this method should be used for any object instance that is coming from a serialized
+        storage, from another thread (assuming the regular threaded unit of work model), or any
+        case where the instance was loaded/created corresponding to a different base unitofwork
+        than the current one."""
+        if instance is None:
+            return None
+        key = getattr(instance, '_instance_key', None)
+        mapper = object_mapper(instance)
+        key = (key[0], mapper.table.hash_key(), key[2])
+        u = self.registry()
+        if key is not None:
+            if u.identity_map.has_key(key):
+                return u.identity_map[key]
+            else:
+                instance._instance_key = key
+                u.identity_map[key] = instance
+                self.bind_to(instance)
+        else:
+            u.register_new(instance)
+        return instance
+    
 
-    table - a Table object where the object's primary fields are stored.
+def get_id_key(ident, class_, table):
+    return Session.get_id_key(ident, class_, table)
 
-    selectable - a Selectable object which represents all the object's column-based fields.
-    this Selectable may be synonymous with the table argument or can be a larger construct
-    containing that table. return value: a tuple object which is used as an identity key.
-    """
-    return (class_, table.hash_key(), tuple([row[column] for column in primary_key]))
+def get_row_key(row, class_, table, primary_key):
+    return Session.get_row_key(row, class_, table, primary_key)
 
 def begin():
     """begins a new UnitOfWork transaction.  the next commit will affect only
     objects that are created, modified, or deleted following the begin statement."""
-    uow().begin()
-    
+    session().begin()
+
 def commit(*obj):
     """commits the current UnitOfWork transaction.  if a transaction was begun 
     via begin(), commits only those objects that were created, modified, or deleted
     since that begin statement.  otherwise commits all objects that have been
     changed."""
-    uow().commit(*obj)
-    
+    session().commit(*obj)
+
 def clear():
     """removes all current UnitOfWorks and IdentityMaps for this thread and 
     establishes a new one.  It is probably a good idea to discard all
     current mapped object instances, as they are no longer in the Identity Map."""
-    uow.set(UnitOfWork())
+    session().clear()
 
 def delete(*obj):
     """registers the given objects as to be deleted upon the next commit"""
-    uw = uow()
+    s = session()
     for o in obj:
-        uw.register_deleted(o)
-    
+        s.register_deleted(o)
+
 def has_key(key):
     """returns True if the current thread-local IdentityMap contains the given instance key"""
-    return uow().identity_map.has_key(key)
+    return session().has_key(key)
 
 def has_instance(instance):
     """returns True if the current thread-local IdentityMap contains the given instance"""
-    return uow().identity_map.has_key(instance_key(instance))
+    return session().has_instance(instance)
 
 def is_dirty(obj):
     """returns True if the given object is in the current UnitOfWork's new or dirty list,
     or if its a modified list attribute on an object."""
-    return uow().is_dirty(obj)
-    
+    return session().is_dirty(obj)
+
 def instance_key(instance):
     """returns the IdentityMap key for the given instance"""
-    return object_mapper(instance).instance_key(instance)
+    return session().instance_key(instance)
 
 def import_instance(instance):
-    """places the given instance in the current thread's unit of work context,
-    either in the current IdentityMap or marked as "new".  Returns either the object
-    or the current corresponding version in the Identity Map.
-    
-    this method should be used for any object instance that is coming from a serialized
-    storage, from another thread (assuming the regular threaded unit of work model), or any
-    case where the instance was loaded/created corresponding to a different base unitofwork
-    than the current one."""
-    if instance is None:
-        return None
-    key = getattr(instance, '_instance_key', None)
-    mapper = object_mapper(instance)
-    key = (key[0], mapper.table.hash_key(), key[2])
-    u = uow()
-    if key is not None:
-        if u.identity_map.has_key(key):
-            return u.identity_map[key]
-        else:
-            instance._instance_key = key
-            u.identity_map[key] = instance
-    else:
-        u.register_new(instance)
-    return instance
-    
+    return session().import_instance(instance)
+
 class UOWListElement(attributes.ListElement):
     def __init__(self, obj, key, data=None, deleteremoved=False, **kwargs):
         attributes.ListElement.__init__(self, obj, key, data=data, **kwargs)
         self.deleteremoved = deleteremoved
     def list_value_changed(self, obj, key, item, listval, isdelete):
-        if not isdelete and uow().deleted.contains(item):
+        sess = session(obj)
+        if not isdelete and sess.deleted.contains(item):
             raise "re-inserting a deleted value into a list"
-        uow().modified_lists.append(self)
+        sess.modified_lists.append(self)
         if self.deleteremoved and isdelete:
-            uow().register_deleted(item)
+            sess.register_deleted(item)
     def append(self, item, _mapper_nohistory = False):
         if _mapper_nohistory:
             self.append_nohistory(item)
@@ -146,15 +211,16 @@ class UOWAttributeManager(attributes.AttributeManager):
         
     def value_changed(self, obj, key, value):
         if hasattr(obj, '_instance_key'):
-            uow().register_dirty(obj)
+            session(obj).register_dirty(obj)
         else:
-            uow().register_new(obj)
+            session(obj).register_new(obj)
 
     def create_list(self, obj, key, list_, **kwargs):
         return UOWListElement(obj, key, list_, **kwargs)
         
 class UnitOfWork(object):
-    def __init__(self, parent = None, is_begun = False):
+    def __init__(self, session, parent=None, is_begun=False):
+        self.session = session
         self.is_begun = is_begun
         if is_begun:
             self.begin_count = 1
@@ -234,10 +300,12 @@ class UnitOfWork(object):
         if not hasattr(obj, '_instance_key'):
             mapper = object_mapper(obj)
             obj._instance_key = mapper.instance_key(obj)
+        self.session.bind_to(obj)
         self._put(obj._instance_key, obj)
         self.attributes.commit(obj)
         
     def register_new(self, obj):
+        self.session.bind_to(obj)
         self.new.append(obj)
         
     def register_dirty(self, obj):
@@ -267,8 +335,8 @@ class UnitOfWork(object):
         if self.is_begun:
             self.begin_count += 1
             return
-        u = UnitOfWork(selfTrue)
-        uow.set(u)
+        u = UnitOfWork(self.session, parent=self, is_begun=True)
+        self.session.registry.set(u)
         
     def commit(self, *objects):
         if self.is_begun:
@@ -330,7 +398,7 @@ class UnitOfWork(object):
         commit_context.post_exec()
         
         if self.parent:
-            uow.set(self.parent)
+            self.session.registry.set(self.parent)
 
     def rollback_object(self, obj):
         self.attributes.rollback(obj)
@@ -341,7 +409,7 @@ class UnitOfWork(object):
         # roll back attributes ?  nah....
         #for obj in self.deleted + self.dirty + self.new:
         #    self.attributes.rollback(obj)
-        uow.set(self.parent)
+        self.session.registry.set(self.parent)
             
 class UOWTransaction(object):
     """handles the details of organizing and executing transaction tasks 
@@ -908,4 +976,42 @@ def object_mapper(obj):
     return sqlalchemy.mapperlib.object_mapper(obj)
 
 global_attributes = UOWAttributeManager()
-uow = util.ScopedRegistry(lambda: UnitOfWork(), "thread")
+
+thread_session = Session(registry=util.ScopedRegistry(lambda: UnitOfWork(thread_session), "thread"), hash_key='thread')
+uow = thread_session.registry # Note: this is not a UnitOfWork, it is a ScopedRegistry that manages UnitOfWork objects
+
+_sessions = weakref.WeakValueDictionary()
+_sessions[thread_session.hash_key] = thread_session
+
+def session(obj=None):
+    # object-specific session ?
+    if obj is not None:
+        # does it have a hash key ?
+        hashkey = getattr(obj, '_sa_session_id', None)
+        if hashkey is not None:
+            # ok, return that
+            try:
+                return _sessions[hashkey]
+            except KeyError:
+                # oh, its gone, nevermind
+                pass
+
+    try:
+        # have a thread-locally defined session (via using_session) ?
+        return _sessions[thread.get_ident()]
+    except KeyError:
+        # nope, return the regular session
+        return thread_session
+    
+def using_session(sess, func):
+    old = _sessions.get(thread.get_ident(), None)
+    try:
+        _sessions[sess.hash_key] = sess
+        _sessions[thread.get_ident()] = sess
+        return func()
+    finally:
+        if old is not None:
+            _session[thread.get_ident()] = old
+        else:
+            del _session[thread.get_ident()]
+
index 665ab4f537930151686b1e237ecd03e58038f7d5..45177f838bec9434fb704fa610b10b6dd7828fc9 100644 (file)
@@ -357,56 +357,45 @@ class HistoryArraySet(UserList.UserList):
         
 class ScopedRegistry(object):
     """a Registry that can store one or multiple instances of a single class 
-    on a per-application or per-thread scoped basis"""
+    on a per-application or per-thread scoped basis
+    
+    createfunc - a callable that returns a new object to be placed in the registry
+    defaultscope - the default scope to be used ('application', 'thread', or 'session')
+    """
     def __init__(self, createfunc, defaultscope):
         self.createfunc = createfunc
         self.defaultscope = defaultscope
-        self.application = createfunc()
-        self.threadlocal = {}
         self.scopes = {
-            'application' : {'call' : self._call_application, 'clear' : self._clear_application, 'set':self._set_application}, 
-            'thread' : {'call' : self._call_thread, 'clear':self._clear_thread, 'set':self._set_thread}
-            }
-
-    def __call__(self, scope = None):
-        if scope is None:
-            scope = self.defaultscope
-        return self.scopes[scope]['call']()
+            "application": lambda:None,
+            "thread": thread.get_ident,
+        }
+        self.registry = {}
 
-    def set(self, obj, scope = None):
-        if scope is None:
-            scope = self.defaultscope
-        return self.scopes[scope]['set'](obj)
-        
-    def clear(self, scope = None):
-        if scope is None:
-            scope = self.defaultscope
-        return self.scopes[scope]['clear']()
+    def add_scope(self, scope, keyfunc, default=True):
+        self.scopes[scope] = keyfunc
+        if default:
+            self.defaultscope = scope
 
-    def _set_thread(self, obj):
-        self.threadlocal[thread.get_ident()] = obj
-    
-    def _call_thread(self):
+    def __call__(self, scope=None):
+        key = self._get_key(scope)
         try:
-            return self.threadlocal[thread.get_ident()]
+            return self.registry[key]
         except KeyError:
-            return self.threadlocal.setdefault(thread.get_ident(), self.createfunc())
+            return self.registry.setdefault(key, self.createfunc())
 
-    def _clear_thread(self):
+    def set(self, obj, scope=None):
+        self.registry[self._get_key(scope)] = obj
+        
+    def clear(self, scope=None):
         try:
-            del self.threadlocal[thread.get_ident()]
+            del self.registry[self._get_key(scope)]
         except KeyError:
             pass
 
-    def _set_application(self, obj):
-        self.application = obj
-        
-    def _call_application(self):
-        return self.application
-
-    def _clear_application(self):
-        self.application = createfunc()
-                
+    def _get_key(self, scope, *args, **kwargs):
+        if scope is None:
+            scope = self.defaultscope
+        return (scope, self.scopes[scope]())
 
 
 def constructor_args(instance, **kwargs):