]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
bpo-37193: remove thread objects which finished process its request (GH-13893)
authorMARUYAMA Norihiro <norihiro.maruyama@gmail.com>
Sun, 1 Nov 2020 23:51:04 +0000 (08:51 +0900)
committerGitHub <noreply@github.com>
Sun, 1 Nov 2020 23:51:04 +0000 (18:51 -0500)
* bpo-37193: remove the thread which finished process request from threads list

* rename variable t to thread.

* don't remove thread from list if it is daemon.

* use lock to protect self._threads.

* use finally block in case of exception from shutdown_request().

* check "not thread.daemon" before lock to avoid holding the lock if it's unnecessary.

* fix the place of _threads_lock.

* separate code to remove a current thread into a function.

* check ValueError when removing thread.

* fix wrong code which all instance shared same lock.

* Extract thread management into a _Threads class to encapsulate atomic operations and separate concerns.

* Replace multiple references of 'block_on_close' with one, avoiding the possibility that 'block_on_close' could change during the course of processing requests. Now, there's exactly one _threads object with behavior fixed for the duration.

* Add docstrings to private classes.

* Add test to ensure that a ThreadingTCPServer can be closed without serving any requests.

* Use _NoThreads as the default value. Fixes AttributeError when server is closed without serving any requests.

* Add blurb

* Add test capturing failure.

Co-authored-by: Jason R. Coombs <jaraco@jaraco.com>
Lib/socketserver.py
Lib/test/test_socketserver.py
Misc/NEWS.d/next/Library/2020-06-12-21-23-20.bpo-37193.wJximU.rst [new file with mode: 0644]

index 57c1ae6e9e8be187f6ed31aea8236761bb3d600f..6859b69682e9720148249f90085ee9cf7e6e29e9 100644 (file)
@@ -128,6 +128,7 @@ import selectors
 import os
 import sys
 import threading
+import contextlib
 from io import BufferedIOBase
 from time import monotonic as time
 
@@ -628,6 +629,55 @@ if hasattr(os, "fork"):
             self.collect_children(blocking=self.block_on_close)
 
 
+class _Threads(list):
+    """
+    Joinable list of all non-daemon threads.
+    """
+    def __init__(self):
+        self._lock = threading.Lock()
+
+    def append(self, thread):
+        if thread.daemon:
+            return
+        with self._lock:
+            super().append(thread)
+
+    def remove(self, thread):
+        with self._lock:
+            # should not happen, but safe to ignore
+            with contextlib.suppress(ValueError):
+                super().remove(thread)
+
+    def remove_current(self):
+        """Remove a current non-daemon thread."""
+        thread = threading.current_thread()
+        if not thread.daemon:
+            self.remove(thread)
+
+    def pop_all(self):
+        with self._lock:
+            self[:], result = [], self[:]
+        return result
+
+    def join(self):
+        for thread in self.pop_all():
+            thread.join()
+
+
+class _NoThreads:
+    """
+    Degenerate version of _Threads.
+    """
+    def append(self, thread):
+        pass
+
+    def join(self):
+        pass
+
+    def remove_current(self):
+        pass
+
+
 class ThreadingMixIn:
     """Mix-in class to handle each request in a new thread."""
 
@@ -636,9 +686,9 @@ class ThreadingMixIn:
     daemon_threads = False
     # If true, server_close() waits until all non-daemonic threads terminate.
     block_on_close = True
-    # For non-daemonic threads, list of threading.Threading objects
+    # Threads object
     # used by server_close() to wait for all threads completion.
-    _threads = None
+    _threads = _NoThreads()
 
     def process_request_thread(self, request, client_address):
         """Same as in BaseServer but as a thread.
@@ -651,27 +701,24 @@ class ThreadingMixIn:
         except Exception:
             self.handle_error(request, client_address)
         finally:
-            self.shutdown_request(request)
+            try:
+                self.shutdown_request(request)
+            finally:
+                self._threads.remove_current()
 
     def process_request(self, request, client_address):
         """Start a new thread to process the request."""
+        if self.block_on_close:
+            vars(self).setdefault('_threads', _Threads())
         t = threading.Thread(target = self.process_request_thread,
                              args = (request, client_address))
         t.daemon = self.daemon_threads
-        if not t.daemon and self.block_on_close:
-            if self._threads is None:
-                self._threads = []
-            self._threads.append(t)
+        self._threads.append(t)
         t.start()
 
     def server_close(self):
         super().server_close()
-        if self.block_on_close:
-            threads = self._threads
-            self._threads = None
-            if threads:
-                for thread in threads:
-                    thread.join()
+        self._threads.join()
 
 
 if hasattr(os, "fork"):
index 7cdd115a951539958cd9517622b05d0a7f583295..1944795f0589468a0ea325beb43aa7df40769810 100644 (file)
@@ -277,6 +277,13 @@ class SocketServerTest(unittest.TestCase):
             t.join()
             s.server_close()
 
+    def test_close_immediately(self):
+        class MyServer(socketserver.ThreadingMixIn, socketserver.TCPServer):
+            pass
+
+        server = MyServer((HOST, 0), lambda: None)
+        server.server_close()
+
     def test_tcpserver_bind_leak(self):
         # Issue #22435: the server socket wouldn't be closed if bind()/listen()
         # failed.
@@ -491,6 +498,23 @@ class MiscTestCase(unittest.TestCase):
         self.assertEqual(server.shutdown_called, 1)
         server.server_close()
 
+    def test_threads_reaped(self):
+        """
+        In #37193, users reported a memory leak
+        due to the saving of every request thread. Ensure that the
+        threads are cleaned up after the requests complete.
+        """
+        class MyServer(socketserver.ThreadingMixIn, socketserver.TCPServer):
+            pass
+
+        server = MyServer((HOST, 0), socketserver.StreamRequestHandler)
+        for n in range(10):
+            with socket.create_connection(server.server_address):
+                server.handle_request()
+        [thread.join() for thread in server._threads]
+        self.assertEqual(len(server._threads), 0)
+        server.server_close()
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/Misc/NEWS.d/next/Library/2020-06-12-21-23-20.bpo-37193.wJximU.rst b/Misc/NEWS.d/next/Library/2020-06-12-21-23-20.bpo-37193.wJximU.rst
new file mode 100644 (file)
index 0000000..fbf56d3
--- /dev/null
@@ -0,0 +1,2 @@
+Fixed memory leak in ``socketserver.ThreadingMixIn`` introduced in Python
+3.7.