]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- [bug] Added __reduce__ to StatementError,
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 10 Jan 2012 16:15:11 +0000 (11:15 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 10 Jan 2012 16:15:11 +0000 (11:15 -0500)
DBAPIError so that exceptions are pickleable,
as when using multiprocessing.  However, not
all DBAPIs support this yet, such as
psycopg2. [ticket:2371]

CHANGES
lib/sqlalchemy/exc.py
test/engine/test_execute.py

diff --git a/CHANGES b/CHANGES
index a55d5953f650898d5472bbfd25947b76980cd052..2457c966893acfc97e280c3aeb4f8da5741a33e0 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -23,6 +23,13 @@ CHANGES
     criteria which will join via AND, i.e.
     query.filter(x==y, z>q, ...)
 
+- engine
+  - [bug] Added __reduce__ to StatementError, 
+    DBAPIError so that exceptions are pickleable,
+    as when using multiprocessing.  However, not 
+    all DBAPIs support this yet, such as 
+    psycopg2. [ticket:2371]
+
 - sqlite
   - [bug] the "name" of an FK constraint in SQLite
     is reflected as "None", not "0" or other 
index 59a15f07981d722f4272fe9b0b4b47f02ba94e0e..55205b2744290b971e721e4ef8703007b47383fb 100644 (file)
@@ -164,6 +164,10 @@ class StatementError(SQLAlchemyError):
         self.params = params
         self.orig = orig
 
+    def __reduce__(self):
+        return self.__class__, (self.message, self.statement, 
+                                self.params, self.orig)
+
     def __str__(self):
         from sqlalchemy.sql import util
         params_repr = util._repr_params(self.params, 10)
@@ -219,6 +223,10 @@ class DBAPIError(StatementError):
 
         return cls(statement, params, orig, connection_invalidated)
 
+    def __reduce__(self):
+        return self.__class__, (self.statement, self.params, 
+                    self.orig, self.connection_invalidated)
+
     def __init__(self, statement, params, orig, connection_invalidated=False):
         try:
             text = str(orig)
index 49457644acd1c882d2dd01b464af1f3efcf09b44..eaef9e43a46c9ab8cb332537c6081496b9e48ddb 100644 (file)
@@ -1,5 +1,6 @@
 from test.lib.testing import eq_, assert_raises, assert_raises_message, config
 import re
+from test.lib.util import picklers
 from sqlalchemy.interfaces import ConnectionProxy
 from sqlalchemy import MetaData, Integer, String, INT, VARCHAR, func, \
     bindparam, select, event, TypeDecorator
@@ -188,6 +189,42 @@ class ExecuteTest(fixtures.TestBase):
         finally:
             conn.close()
 
+    def test_stmt_exception_pickleable_no_dbapi(self):
+        self._test_stmt_exception_pickleable(Exception("hello world"))
+
+    @testing.fails_on("postgresql+psycopg2", 
+                "Packages the cursor in the exception")
+    def test_stmt_exception_pickleable_plus_dbapi(self):
+        raw = testing.db.raw_connection()
+        try:
+            cursor = raw.cursor()
+            cursor.execute("SELECTINCORRECT")
+        except testing.db.dialect.dbapi.DatabaseError, orig:
+            pass
+        finally:
+            raw.close()
+        self._test_stmt_exception_pickleable(orig)
+
+    def _test_stmt_exception_pickleable(self, orig):
+        for sa_exc in (
+            tsa.exc.StatementError("some error", 
+                            "select * from table", 
+                           {"foo":"bar"}, 
+                            orig),
+            tsa.exc.InterfaceError("select * from table", 
+                            {"foo":"bar"}, 
+                            orig),
+        ):
+            for loads, dumps in picklers():
+                repickled = loads(dumps(sa_exc))
+                eq_(repickled.message, sa_exc.message)
+                eq_(repickled.params, {"foo":"bar"})
+                eq_(repickled.statement, sa_exc.statement)
+                if hasattr(sa_exc, "connection_invalidated"):
+                    eq_(repickled.connection_invalidated, 
+                        sa_exc.connection_invalidated)
+                eq_(repickled.orig.message, orig.message)
+
     def test_dont_wrap_mixin(self):
         class MyException(Exception, tsa.exc.DontWrapMixin):
             pass