]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- The "extension" argument to Session and others can now
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 29 Aug 2008 16:31:58 +0000 (16:31 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 29 Aug 2008 16:31:58 +0000 (16:31 +0000)
optionally be a list, supporting events sent to multiple
SessionExtension instances.  Session places SessionExtensions
in Session.extensions.

CHANGES
lib/sqlalchemy/orm/session.py

diff --git a/CHANGES b/CHANGES
index 9b163c7c937e02bd0e8a12375ea723e16b60617b..989fa0200ba6099e69ecd8af8feb3262bb357320 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -49,6 +49,11 @@ CHANGES
       change the state of the Session before the flush proceeds.
       [ticket:1128]
 
+    - The "extension" argument to Session and others can now
+      optionally be a list, supporting events sent to multiple
+      SessionExtension instances.  Session places SessionExtensions
+      in Session.extensions.
+      
     - Reentrant calls to flush() raise an error.  This also serves
       as a rudimentary, but not foolproof, check against concurrent
       calls to Session.flush().
index b46ae16d3213ace6ccb72067907090645f6cf76f..7195d5b1b027f773f238117550f559f95d7e576b 100644 (file)
@@ -139,7 +139,8 @@ def sessionmaker(bind=None, class_=None, autoflush=True, autocommit=False,
       ``logging.getLogger('sqlalchemy.orm.unitofwork').setLevel(logging.DEBUG)``.
 
     extension
-      An optional [sqlalchemy.orm.session#SessionExtension] instance, which
+      An optional [sqlalchemy.orm.session#SessionExtension] instance, or
+      a list of such instances, which
       will receive pre- and post- commit and flush events, as well as a
       post-rollback event.  User- defined code may be placed within these
       hooks using a user-defined subclass of ``SessionExtension``.
@@ -333,8 +334,8 @@ class SessionTransaction(object):
 
         self._connections[conn] = self._connections[conn.engine] = \
           (conn, transaction, conn is not bind)
-        if self.session.extension is not None:
-            self.session.extension.after_begin(self.session, self, conn)
+        for ext in self.session.extensions:
+            ext.after_begin(self.session, self, conn)
         return conn
 
     def prepare(self):
@@ -345,9 +346,9 @@ class SessionTransaction(object):
 
     def _prepare_impl(self):
         self._assert_is_active()
-        if (self.session.extension is not None and
-            (self._parent is None or self.nested)):
-            self.session.extension.before_commit(self.session)
+        if self._parent is None or self.nested:
+            for ext in self.session.extensions:
+                ext.before_commit(self.session)
 
         stx = self.session.transaction
         if stx is not self:
@@ -377,8 +378,8 @@ class SessionTransaction(object):
             for t in set(self._connections.values()):
                 t[1].commit()
 
-            if self.session.extension is not None:
-                self.session.extension.after_commit(self.session)
+            for ext in self.session.extensions:
+                ext.after_commit(self.session)
 
             if self.session._enable_transaction_accounting:
                 self._remove_snapshot()
@@ -413,8 +414,8 @@ class SessionTransaction(object):
         if self.session._enable_transaction_accounting:
             self._restore_snapshot()
 
-        if self.session.extension is not None:
-            self.session.extension.after_rollback(self.session)
+        for ext in self.session.extensions:
+            ext.after_rollback(self.session)
 
     def _deactivate(self):
         self._active = False
@@ -558,7 +559,7 @@ class Session(object):
         self.expire_on_commit = expire_on_commit
         self._enable_transaction_accounting = _enable_transaction_accounting
         self.twophase = twophase
-        self.extension = extension
+        self.extensions = util.to_list(extension) or []
         self._query_cls = query_cls
         self._mapper_flush_opts = {}
 
@@ -1303,8 +1304,8 @@ class Session(object):
                                     state.session_id, self.hash_key))
         if state.session_id != self.hash_key:
             state.session_id = self.hash_key
-        if self.extension is not None:
-            self.extension.after_attach(self, state.obj())
+        for ext in self.extensions:
+            ext.after_attach(self, state.obj())
 
     def __contains__(self, instance):
         """Return True if the instance is associated with this session.
@@ -1371,8 +1372,8 @@ class Session(object):
 
         flush_context = UOWTransaction(self)
 
-        if self.extension is not None:
-            self.extension.before_flush(self, flush_context, objects)
+        for ext in self.extensions:
+            ext.before_flush(self, flush_context, objects)
 
         deleted = set(self._deleted)
         new = set(self._new)
@@ -1425,8 +1426,8 @@ class Session(object):
         try:
             flush_context.execute()
 
-            if self.extension is not None:
-                self.extension.after_flush(self, flush_context)
+            for ext in self.extensions:
+                ext.after_flush(self, flush_context)
             transaction.commit()
         except:
             transaction.rollback()
@@ -1437,8 +1438,8 @@ class Session(object):
         if not objects:
             self.identity_map.modified = False
 
-        if self.extension is not None:
-            self.extension.after_flush_postexec(self, flush_context)
+        for ext in self.extensions:
+            ext.after_flush_postexec(self, flush_context)
 
     def is_modified(self, instance, include_collections=True, passive=False):
         """Return True if instance has modified attributes.