git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
catch all BaseException in pool and revert failed checkouts
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 12 Dec 2022 18:47:27 +0000 (13:47 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 13 Dec 2022 14:45:22 +0000 (09:45 -0500)
Fixed a long-standing race condition in the connection pool which could
occur under eventlet/gevent monkeypatching schemes in conjunction with the
use of eventlet/gevent ``Timeout`` conditions, where a connection pool
checkout that's interrupted due to the timeout would fail to clean up the
failed state, causing the underlying connection record and sometimes the
database connection itself to "leak", leaving the pool in an invalid state
with unreachable entries. This issue was first identified and fixed in
SQLAlchemy 1.2 for :ticket:`4225`, however the failure modes detected in
that fix failed to account for ``BaseException``, rather than
``Exception``, which prevented eventlet/gevent ``Timeout`` from being
caught. In addition, a block within the initial pool connect has also been
identified and hardened with a ``BaseException`` -> "clean failed connect"
block to account for the same condition in this location.
Big thanks to GitHub user @niklaus for their tenacious efforts in
identifying and describing this intricate issue.

Fixes: #8974
Change-Id: I95a0e1f080d0cee6f1a66977432a586fdf87f686

doc/build/changelog/unreleased_14/8974.rst [new file with mode: 0644]
lib/sqlalchemy/pool/base.py
test/engine/test_pool.py

diff --git a/doc/build/changelog/unreleased_14/8974.rst b/doc/build/changelog/unreleased_14/8974.rst
new file mode 100644 (file)
index 0000000..6400c95
--- /dev/null
@@ -0,0 +1,19 @@
+.. change::
+    :tags: bug, pool
+    :tickets: 8974
+
+    Fixed a long-standing race condition in the connection pool which could
+    occur under eventlet/gevent monkeypatching schemes in conjunction with the
+    use of eventlet/gevent ``Timeout`` conditions, where a connection pool
+    checkout that's interrupted due to the timeout would fail to clean up the
+    failed state, causing the underlying connection record and sometimes the
+    database connection itself to "leak", leaving the pool in an invalid state
+    with unreachable entries. This issue was first identified and fixed in
+    SQLAlchemy 1.2 for :ticket:`4225`, however the failure modes detected in
+    that fix failed to account for ``BaseException``, rather than
+    ``Exception``, which prevented eventlet/gevent ``Timeout`` from being
+    caught. In addition, a block within the initial pool connect has also been
+    identified and hardened with a ``BaseException`` -> "clean failed connect"
+    block to account for the same condition in this location.
+    Big thanks to GitHub user @niklaus for their tenacious efforts in
+    identifying and describing this intricate issue.
index 47c39791c5c2011eedacfdf358c0c558200a08dd..7b211afd9ef70a2c7c4dd3f96d46c22b917ae86e 100644 (file)
@@ -380,10 +380,12 @@ class Pool(log.Identified, event.EventTarget):
                 self._dialect.do_terminate(connection)
             else:
                 self._dialect.do_close(connection)
-        except Exception:
+        except BaseException as e:
             self.logger.error(
                 "Exception closing connection %r", connection, exc_info=True
             )
+            if not isinstance(e, Exception):
+                raise
 
     def _create_connection(self) -> ConnectionPoolEntry:
         """Called by subclasses to create a new ConnectionRecord."""
@@ -714,9 +716,11 @@ class _ConnectionRecord(ConnectionPoolEntry):
 
         try:
             dbapi_connection = rec.get_connection()
-        except Exception as err:
+        except BaseException as err:
             with util.safe_reraise():
                 rec._checkin_failed(err, _fairy_was_created=False)
+
+            # not reached, for code linters only
             raise
 
         echo = pool._should_log_debug()
@@ -738,7 +742,7 @@ class _ConnectionRecord(ConnectionPoolEntry):
         return fairy
 
     def _checkin_failed(
-        self, err: Exception, _fairy_was_created: bool = True
+        self, err: BaseException, _fairy_was_created: bool = True
     ) -> None:
         self.invalidate(e=err)
         self.checkin(
@@ -893,7 +897,7 @@ class _ConnectionRecord(ConnectionPoolEntry):
             self.dbapi_connection = connection = pool._invoke_creator(self)
             pool.logger.debug("Created new connection %r", connection)
             self.fresh = True
-        except Exception as e:
+        except BaseException as e:
             with util.safe_reraise():
                 pool.logger.debug("Error on connect(): %s", e)
         else:
@@ -1271,6 +1275,7 @@ class _ConnectionFairy(PoolProxiedConnection):
         # here.
 
         attempts = 2
+
         while attempts > 0:
             connection_is_fresh = fairy._connection_record.fresh
             fairy._connection_record.fresh = False
@@ -1323,7 +1328,7 @@ class _ConnectionFairy(PoolProxiedConnection):
                     fairy.dbapi_connection = (
                         fairy._connection_record.get_connection()
                     )
-                except Exception as err:
+                except BaseException as err:
                     with util.safe_reraise():
                         fairy._connection_record._checkin_failed(
                             err,
@@ -1341,6 +1346,21 @@ class _ConnectionFairy(PoolProxiedConnection):
                     raise
 
                 attempts -= 1
+            except BaseException as be_outer:
+                with util.safe_reraise():
+                    rec = fairy._connection_record
+                    if rec is not None:
+                        rec._checkin_failed(
+                            be_outer,
+                            _fairy_was_created=True,
+                        )
+
+                    # prevent _ConnectionFairy from being carried
+                    # in the stack trace, see above
+                    del fairy
+
+                # never called, this is for code linters
+                raise
 
         pool.logger.info("Reconnection attempts exhausted on checkout")
         fairy.invalidate()
index f267eac7792894b7ec28fbb0df3fa0145079bdb9..4fddcc871bb0ed8ba30e3cb8402366ec998a359a 100644 (file)
@@ -835,18 +835,34 @@ class PoolEventsTest(PoolTestBase):
         p2.connect()
         eq_(canary, ["listen_one", "listen_two", "listen_one", "listen_three"])
 
-    def test_connect_event_fails_invalidates(self):
+    @testing.variation("exc_type", ["plain", "base_exception"])
+    def test_connect_event_fails_invalidates(self, exc_type):
         fail = False
 
+        if exc_type.plain:
+
+            class RegularThing(Exception):
+                pass
+
+            exc_cls = RegularThing
+        elif exc_type.base_exception:
+
+            class TimeoutThing(BaseException):
+                pass
+
+            exc_cls = TimeoutThing
+        else:
+            exc_type.fail()
+
         def listen_one(conn, rec):
             if fail:
-                raise Exception("it failed")
+                raise exc_cls("it failed")
 
         def listen_two(conn, rec):
             rec.info["important_flag"] = True
 
         p1 = pool.QueuePool(
-            creator=MockDBAPI().connect, pool_size=1, max_overflow=0
+            creator=MockDBAPI().connect, pool_size=1, max_overflow=0, timeout=5
         )
         event.listen(p1, "connect", listen_one)
         event.listen(p1, "connect", listen_two)
@@ -857,7 +873,9 @@ class PoolEventsTest(PoolTestBase):
         conn.close()
 
         fail = True
-        assert_raises(Exception, p1.connect)
+
+        # if the failed checkin is not reverted, the pool is blocked
+        assert_raises(exc_cls, p1.connect)
 
         fail = False
 
@@ -1493,7 +1511,7 @@ class QueuePoolTest(PoolTestBase):
 
         return patch.object(pool, "_finalize_fairy", assert_no_wr_callback)
 
-    def _assert_cleanup_on_pooled_reconnect(self, dbapi, p):
+    def _assert_cleanup_on_pooled_reconnect(self, dbapi, p, exc_cls=Exception):
         # p is QueuePool with size=1, max_overflow=2,
         # and one connection in the pool that will need to
         # reconnect when next used (either due to recycle or invalidate)
@@ -1502,7 +1520,7 @@ class QueuePoolTest(PoolTestBase):
             eq_(p.checkedout(), 0)
             eq_(p._overflow, 0)
             dbapi.shutdown(True)
-            assert_raises_context_ok(Exception, p.connect)
+            assert_raises_context_ok(exc_cls, p.connect)
             eq_(p._overflow, 0)
 
             eq_(p.checkedout(), 0)  # and not 1
@@ -1620,18 +1638,38 @@ class QueuePoolTest(PoolTestBase):
         c = p.connect()
         c.close()
 
-    def test_error_on_pooled_reconnect_cleanup_wcheckout_event(self):
+    @testing.variation("exc_type", ["plain", "base_exception"])
+    def test_error_on_pooled_reconnect_cleanup_wcheckout_event(self, exc_type):
         dbapi, p = self._queuepool_dbapi_fixture(pool_size=1, max_overflow=2)
 
         c1 = p.connect()
         c1.close()
 
-        @event.listens_for(p, "checkout")
-        def handle_checkout_event(dbapi_con, con_record, con_proxy):
-            if dbapi.is_shutdown:
-                raise tsa.exc.DisconnectionError()
+        if exc_type.plain:
 
-        self._assert_cleanup_on_pooled_reconnect(dbapi, p)
+            @event.listens_for(p, "checkout")
+            def handle_checkout_event(dbapi_con, con_record, con_proxy):
+                if dbapi.is_shutdown:
+                    raise tsa.exc.DisconnectionError()
+
+        elif exc_type.base_exception:
+
+            class TimeoutThing(BaseException):
+                pass
+
+            @event.listens_for(p, "checkout")
+            def handle_checkout_event(dbapi_con, con_record, con_proxy):
+                if dbapi.is_shutdown:
+                    raise TimeoutThing()
+
+        else:
+            exc_type.fail()
+
+        self._assert_cleanup_on_pooled_reconnect(
+            dbapi,
+            p,
+            exc_cls=TimeoutThing if exc_type.base_exception else Exception,
+        )
 
     @testing.combinations((True,), (False,))
     def test_userspace_disconnectionerror_weakref_finalizer(self, detach_gced):