]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
bpo-34745: Fix asyncio sslproto memory issues (GH-12386)
authorMiss Islington (bot) <31488909+miss-islington@users.noreply.github.com>
Sun, 17 Mar 2019 23:09:14 +0000 (16:09 -0700)
committerGitHub <noreply@github.com>
Sun, 17 Mar 2019 23:09:14 +0000 (16:09 -0700)
* Fix handshake timeout leak in asyncio/sslproto

Refs MagicStack/uvloopGH-222

* Break circular ref _SSLPipe <-> SSLProtocol

* bpo-34745: Fix asyncio ssl memory leak

* Break circular ref SSLProtocol <-> UserProtocol

* Add NEWS entry
(cherry picked from commit f683f464259715d620777d7ed568e701337a703a)

Co-authored-by: Fantix King <fantix.king@gmail.com>
Lib/asyncio/sslproto.py
Lib/test/test_asyncio/test_sslproto.py
Misc/NEWS.d/next/Library/2019-03-17-16-43-29.bpo-34745.nOfm7_.rst [new file with mode: 0644]

index 12fdb0d1c5ec1008844a07cd682678d96cce6e3d..15eb9c78443af7f81ff0a8305327572ac6749a94 100644 (file)
@@ -499,7 +499,11 @@ class SSLProtocol(protocols.Protocol):
                 self._app_transport._closed = True
         self._transport = None
         self._app_transport = None
+        if getattr(self, '_handshake_timeout_handle', None):
+            self._handshake_timeout_handle.cancel()
         self._wakeup_waiter(exc)
+        self._app_protocol = None
+        self._sslpipe = None
 
     def pause_writing(self):
         """Called when the low-level transport's buffer goes over
index 44ae3b67e0bdda166763ebfc27be151602ccb764..6c0b0e2a0187eecb9be5cad98cf92aca72a79bf2 100644 (file)
@@ -4,6 +4,7 @@ import logging
 import socket
 import sys
 import unittest
+import weakref
 from unittest import mock
 try:
     import ssl
@@ -270,6 +271,72 @@ class BaseStartTLS(func_tests.FunctionalTestCaseMixin):
             self.loop.run_until_complete(
                 asyncio.wait_for(client(srv.addr), loop=self.loop, timeout=10))
 
+        # No garbage is left if SSL is closed uncleanly
+        client_context = weakref.ref(client_context)
+        self.assertIsNone(client_context())
+
+    def test_create_connection_memory_leak(self):
+        HELLO_MSG = b'1' * self.PAYLOAD_SIZE
+
+        server_context = test_utils.simple_server_sslcontext()
+        client_context = test_utils.simple_client_sslcontext()
+
+        def serve(sock):
+            sock.settimeout(self.TIMEOUT)
+
+            sock.start_tls(server_context, server_side=True)
+
+            sock.sendall(b'O')
+            data = sock.recv_all(len(HELLO_MSG))
+            self.assertEqual(len(data), len(HELLO_MSG))
+
+            sock.shutdown(socket.SHUT_RDWR)
+            sock.close()
+
+        class ClientProto(asyncio.Protocol):
+            def __init__(self, on_data, on_eof):
+                self.on_data = on_data
+                self.on_eof = on_eof
+                self.con_made_cnt = 0
+
+            def connection_made(proto, tr):
+                # XXX: We assume user stores the transport in protocol
+                proto.tr = tr
+                proto.con_made_cnt += 1
+                # Ensure connection_made gets called only once.
+                self.assertEqual(proto.con_made_cnt, 1)
+
+            def data_received(self, data):
+                self.on_data.set_result(data)
+
+            def eof_received(self):
+                self.on_eof.set_result(True)
+
+        async def client(addr):
+            await asyncio.sleep(0.5)
+
+            on_data = self.loop.create_future()
+            on_eof = self.loop.create_future()
+
+            tr, proto = await self.loop.create_connection(
+                lambda: ClientProto(on_data, on_eof), *addr,
+                ssl=client_context)
+
+            self.assertEqual(await on_data, b'O')
+            tr.write(HELLO_MSG)
+            await on_eof
+
+            tr.close()
+
+        with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
+            self.loop.run_until_complete(
+                asyncio.wait_for(client(srv.addr), timeout=10))
+
+        # No garbage is left for SSL client from loop.create_connection, even
+        # if user stores the SSLTransport in corresponding protocol instance
+        client_context = weakref.ref(client_context)
+        self.assertIsNone(client_context())
+
     def test_start_tls_client_buf_proto_1(self):
         HELLO_MSG = b'1' * self.PAYLOAD_SIZE
 
@@ -560,6 +627,11 @@ class BaseStartTLS(func_tests.FunctionalTestCaseMixin):
         # exception or log an error, even if the handshake failed
         self.assertEqual(messages, [])
 
+        # The 10s handshake timeout should be cancelled to free related
+        # objects without really waiting for 10s
+        client_sslctx = weakref.ref(client_sslctx)
+        self.assertIsNone(client_sslctx())
+
     def test_create_connection_ssl_slow_handshake(self):
         client_sslctx = test_utils.simple_client_sslcontext()
 
diff --git a/Misc/NEWS.d/next/Library/2019-03-17-16-43-29.bpo-34745.nOfm7_.rst b/Misc/NEWS.d/next/Library/2019-03-17-16-43-29.bpo-34745.nOfm7_.rst
new file mode 100644 (file)
index 0000000..d88f36a
--- /dev/null
@@ -0,0 +1 @@
+Fix :mod:`asyncio` ssl memory issues caused by circular references