]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Ensure all SQLAlchemy exception can be properly pickled
authorFederico Caselli <cfederico87@gmail.com>
Mon, 27 Sep 2021 19:40:47 +0000 (15:40 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 1 Oct 2021 21:22:57 +0000 (17:22 -0400)
Implemented proper ``__reduce__()`` methods for all SQLAlchemy exception
objects to ensure they all support clean round trips when pickling, as
exception objects are often serialized for the purposes of various
debugging tools.

Fixes  #7077
Closes: #7078
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/7078
Pull-request-sha: 8ba69f26532f0f60f679289702c9477e25149bf8

Change-Id: Id62f8d351cd9180c441ffa9201efcf5f1876bf83

doc/build/changelog/unreleased_14/7077.rst [new file with mode: 0644]
lib/sqlalchemy/exc.py
test/base/test_except.py

diff --git a/doc/build/changelog/unreleased_14/7077.rst b/doc/build/changelog/unreleased_14/7077.rst
new file mode 100644 (file)
index 0000000..305c704
--- /dev/null
@@ -0,0 +1,8 @@
+.. change::
+    :tags: bug, engine
+    :tickets: 7077
+
+    Implemented proper ``__reduce__()`` methods for all SQLAlchemy exception
+    objects to ensure they all support clean round trips when pickling, as
+    exception objects are often serialized for the purposes of various
+    debugging tools.
index a24cf7ba49314c4746823934b391eadf8ebbb44c..afda1b7795ddfb30ff868e91075d7f91d50590c9 100644 (file)
@@ -118,6 +118,10 @@ class ObjectNotExecutableError(ArgumentError):
         super(ObjectNotExecutableError, self).__init__(
             "Not an executable object: %r" % target
         )
+        self.target = target
+
+    def __reduce__(self):
+        return self.__class__, (self.target,)
 
 
 class NoSuchModuleError(ArgumentError):
@@ -164,7 +168,11 @@ class CircularDependencyError(SQLAlchemyError):
         self.edges = edges
 
     def __reduce__(self):
-        return self.__class__, (None, self.cycles, self.edges, self.args[0])
+        return (
+            self.__class__,
+            (None, self.cycles, self.edges, self.args[0]),
+            {"code": self.code} if self.code is not None else {},
+        )
 
 
 class CompileError(SQLAlchemyError):
@@ -188,6 +196,12 @@ class UnsupportedCompilationError(CompileError):
             "Compiler %r can't render element of type %s%s"
             % (compiler, element_type, ": %s" % message if message else "")
         )
+        self.compiler = compiler
+        self.element_type = element_type
+        self.message = message
+
+    def __reduce__(self):
+        return self.__class__, (self.compiler, self.element_type, self.message)
 
 
 class IdentifierError(SQLAlchemyError):
@@ -258,7 +272,7 @@ class ResourceClosedError(InvalidRequestError):
     object that's in a closed state."""
 
 
-class NoSuchColumnError(KeyError, InvalidRequestError):
+class NoSuchColumnError(InvalidRequestError, KeyError):
     """A nonexistent column is requested from a ``Row``."""
 
 
@@ -431,8 +445,10 @@ class StatementError(SQLAlchemyError):
                 self.params,
                 self.orig,
                 self.hide_parameters,
+                self.__dict__.get("code"),
                 self.ismulti,
             ),
+            {"detail": self.detail},
         )
 
     @_preloaded.preload_module("sqlalchemy.sql.util")
@@ -571,8 +587,10 @@ class DBAPIError(StatementError):
                 self.orig,
                 self.hide_parameters,
                 self.connection_invalidated,
+                self.__dict__.get("code"),
                 self.ismulti,
             ),
+            {"detail": self.detail},
         )
 
     def __init__(
index 94dc8520eb41e83252cb2c0e7d776e08a48cd538..be6f448bd84fec5b4cb57a4ce95ca2609ad1ccf7 100644 (file)
@@ -2,9 +2,12 @@
 
 """Tests exceptions and DB-API exception wrapping."""
 
+from itertools import product
+import pickle
 
 from sqlalchemy import exc as sa_exceptions
 from sqlalchemy.engine import default
+from sqlalchemy.testing import combinations_list
 from sqlalchemy.testing import eq_
 from sqlalchemy.testing import fixtures
 from sqlalchemy.util import compat
@@ -414,3 +417,133 @@ class WrapTest(fixtures.TestBase):
             self.assert_(False)
         except SystemExit:
             self.assert_(True)
+
+
+def details(cls):
+    inst = cls("msg", "stmt", (), "orig")
+    inst.add_detail("d1")
+    inst.add_detail("d2")
+    return inst
+
+
+ALL_EXC = [
+    (
+        [sa_exceptions.SQLAlchemyError],
+        [lambda cls: cls(1, 2, code="42")],
+    ),
+    ([sa_exceptions.ObjectNotExecutableError], [lambda cls: cls("xx")]),
+    (
+        [
+            sa_exceptions.ArgumentError,
+            sa_exceptions.NoSuchModuleError,
+            sa_exceptions.NoForeignKeysError,
+            sa_exceptions.AmbiguousForeignKeysError,
+            sa_exceptions.CompileError,
+            sa_exceptions.IdentifierError,
+            sa_exceptions.DisconnectionError,
+            sa_exceptions.InvalidatePoolError,
+            sa_exceptions.TimeoutError,
+            sa_exceptions.InvalidRequestError,
+            sa_exceptions.NoInspectionAvailable,
+            sa_exceptions.PendingRollbackError,
+            sa_exceptions.ResourceClosedError,
+            sa_exceptions.NoSuchColumnError,
+            sa_exceptions.NoResultFound,
+            sa_exceptions.MultipleResultsFound,
+            sa_exceptions.NoReferenceError,
+            sa_exceptions.AwaitRequired,
+            sa_exceptions.MissingGreenlet,
+            sa_exceptions.NoSuchTableError,
+            sa_exceptions.UnreflectableTableError,
+            sa_exceptions.UnboundExecutionError,
+        ],
+        [lambda cls: cls("foo", code="42")],
+    ),
+    (
+        [sa_exceptions.CircularDependencyError],
+        [
+            lambda cls: cls("msg", ["cycles"], "edges"),
+            lambda cls: cls("msg", ["cycles"], "edges", "xx", "zz"),
+        ],
+    ),
+    (
+        [sa_exceptions.UnsupportedCompilationError],
+        [lambda cls: cls("cmp", "el"), lambda cls: cls("cmp", "el", "msg")],
+    ),
+    (
+        [sa_exceptions.NoReferencedTableError],
+        [lambda cls: cls("msg", "tbl")],
+    ),
+    (
+        [sa_exceptions.NoReferencedColumnError],
+        [lambda cls: cls("msg", "tbl", "col")],
+    ),
+    (
+        [sa_exceptions.StatementError],
+        [
+            lambda cls: cls("msg", "stmt", (), "orig"),
+            lambda cls: cls("msg", "stmt", (), "orig", True, "99", True),
+            details,
+        ],
+    ),
+    (
+        [
+            sa_exceptions.DBAPIError,
+            sa_exceptions.InterfaceError,
+            sa_exceptions.DatabaseError,
+            sa_exceptions.DataError,
+            sa_exceptions.OperationalError,
+            sa_exceptions.IntegrityError,
+            sa_exceptions.InternalError,
+            sa_exceptions.ProgrammingError,
+            sa_exceptions.NotSupportedError,
+        ],
+        [
+            lambda cls: cls("stmt", (), "orig"),
+            lambda cls: cls("stmt", (), "orig", True, True, "99", True),
+            details,
+        ],
+    ),
+    (
+        [
+            sa_exceptions.SADeprecationWarning,
+            sa_exceptions.RemovedIn20Warning,
+            sa_exceptions.MovedIn20Warning,
+            sa_exceptions.SAWarning,
+        ],
+        [lambda cls: cls("foo", code="42")],
+    ),
+    ([sa_exceptions.SAPendingDeprecationWarning], [lambda cls: cls(1, 2, 3)]),
+]
+
+
+class PickleException(fixtures.TestBase):
+    def test_all_exc(self):
+        found = {
+            e
+            for e in vars(sa_exceptions).values()
+            if isinstance(e, type) and issubclass(e, Exception)
+        }
+
+        listed = set()
+        for cls_list, _ in ALL_EXC:
+            listed.update(cls_list)
+
+        eq_(found, listed)
+
+    def make_combinations():
+        unroll = []
+        for cls_list, callable_list in ALL_EXC:
+            unroll.extend(product(cls_list, callable_list))
+
+        print(unroll)
+        return combinations_list(unroll)
+
+    @make_combinations()
+    def test_exc(self, cls, ctor):
+        inst = ctor(cls)
+        re_created = pickle.loads(pickle.dumps(inst))
+
+        eq_(re_created.__class__, cls)
+        eq_(re_created.args, inst.args)
+        eq_(re_created.__dict__, inst.__dict__)