]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
bpo-34638: Store a weak reference to stream reader to break strong references loop...
authorAndrew Svetlov <andrew.svetlov@gmail.com>
Wed, 12 Sep 2018 18:43:04 +0000 (11:43 -0700)
committerGitHub <noreply@github.com>
Wed, 12 Sep 2018 18:43:04 +0000 (11:43 -0700)
Store a weak reference to stream readerfor breaking strong references

It breaks the strong reference loop between reader and protocol and allows to detect and close the socket if the stream is deleted (garbage collected)

Lib/asyncio/streams.py
Lib/asyncio/subprocess.py
Lib/test/test_asyncio/test_streams.py
Misc/NEWS.d/next/Library/2018-09-12-10-33-44.bpo-34638.xaeZX5.rst [new file with mode: 0644]

index 9dab49b35e46e8acff0ccfc750d1e4fe7e986264..e7fb22ee5d1ae9d67d23bf9fc75e0cf3a0a572b4 100644 (file)
@@ -3,6 +3,8 @@ __all__ = (
     'open_connection', 'start_server')
 
 import socket
+import sys
+import weakref
 
 if hasattr(socket, 'AF_UNIX'):
     __all__ += ('open_unix_connection', 'start_unix_server')
@@ -10,6 +12,7 @@ if hasattr(socket, 'AF_UNIX'):
 from . import coroutines
 from . import events
 from . import exceptions
+from . import format_helpers
 from . import protocols
 from .log import logger
 from .tasks import sleep
@@ -186,46 +189,106 @@ class StreamReaderProtocol(FlowControlMixin, protocols.Protocol):
     call inappropriate methods of the protocol.)
     """
 
+    _source_traceback = None
+
     def __init__(self, stream_reader, client_connected_cb=None, loop=None):
         super().__init__(loop=loop)
-        self._stream_reader = stream_reader
+        if stream_reader is not None:
+            self._stream_reader_wr = weakref.ref(stream_reader,
+                                                 self._on_reader_gc)
+            self._source_traceback = stream_reader._source_traceback
+        else:
+            self._stream_reader_wr = None
+        if client_connected_cb is not None:
+            # This is a stream created by the `create_server()` function.
+            # Keep a strong reference to the reader until a connection
+            # is established.
+            self._strong_reader = stream_reader
+        self._reject_connection = False
         self._stream_writer = None
+        self._transport = None
         self._client_connected_cb = client_connected_cb
         self._over_ssl = False
         self._closed = self._loop.create_future()
 
+    def _on_reader_gc(self, wr):
+        transport = self._transport
+        if transport is not None:
+            # connection_made was called
+            context = {
+                'message': ('An open stream object is being garbage '
+                            'collected; call "stream.close()" explicitly.')
+            }
+            if self._source_traceback:
+                context['source_traceback'] = self._source_traceback
+            self._loop.call_exception_handler(context)
+            transport.abort()
+        else:
+            self._reject_connection = True
+        self._stream_reader_wr = None
+
+    def _untrack_reader(self):
+        self._stream_reader_wr = None
+
+    @property
+    def _stream_reader(self):
+        if self._stream_reader_wr is None:
+            return None
+        return self._stream_reader_wr()
+
     def connection_made(self, transport):
-        self._stream_reader.set_transport(transport)
+        if self._reject_connection:
+            context = {
+                'message': ('An open stream was garbage collected prior to '
+                            'establishing network connection; '
+                            'call "stream.close()" explicitly.')
+            }
+            if self._source_traceback:
+                context['source_traceback'] = self._source_traceback
+            self._loop.call_exception_handler(context)
+            transport.abort()
+            return
+        self._transport = transport
+        reader = self._stream_reader
+        if reader is not None:
+            reader.set_transport(transport)
         self._over_ssl = transport.get_extra_info('sslcontext') is not None
         if self._client_connected_cb is not None:
             self._stream_writer = StreamWriter(transport, self,
-                                               self._stream_reader,
+                                               reader,
                                                self._loop)
-            res = self._client_connected_cb(self._stream_reader,
+            res = self._client_connected_cb(reader,
                                             self._stream_writer)
             if coroutines.iscoroutine(res):
                 self._loop.create_task(res)
+            self._strong_reader = None
 
     def connection_lost(self, exc):
-        if self._stream_reader is not None:
+        reader = self._stream_reader
+        if reader is not None:
             if exc is None:
-                self._stream_reader.feed_eof()
+                reader.feed_eof()
             else:
-                self._stream_reader.set_exception(exc)
+                reader.set_exception(exc)
         if not self._closed.done():
             if exc is None:
                 self._closed.set_result(None)
             else:
                 self._closed.set_exception(exc)
         super().connection_lost(exc)
-        self._stream_reader = None
+        self._stream_reader_wr = None
         self._stream_writer = None
+        self._transport = None
 
     def data_received(self, data):
-        self._stream_reader.feed_data(data)
+        reader = self._stream_reader
+        if reader is not None:
+            reader.feed_data(data)
 
     def eof_received(self):
-        self._stream_reader.feed_eof()
+        reader = self._stream_reader
+        if reader is not None:
+            reader.feed_eof()
         if self._over_ssl:
             # Prevent a warning in SSLProtocol.eof_received:
             # "returning true from eof_received()
@@ -282,6 +345,9 @@ class StreamWriter:
         return self._transport.can_write_eof()
 
     def close(self):
+        # a reader can be garbage collected
+        # after connection closing
+        self._protocol._untrack_reader()
         return self._transport.close()
 
     def is_closing(self):
@@ -318,6 +384,8 @@ class StreamWriter:
 
 class StreamReader:
 
+    _source_traceback = None
+
     def __init__(self, limit=_DEFAULT_LIMIT, loop=None):
         # The line length limit is  a security feature;
         # it also doubles as half the buffer limit.
@@ -336,6 +404,9 @@ class StreamReader:
         self._exception = None
         self._transport = None
         self._paused = False
+        if self._loop.get_debug():
+            self._source_traceback = format_helpers.extract_stack(
+                sys._getframe(1))
 
     def __repr__(self):
         info = ['StreamReader']
index 90fc00de8339fb2c70cb19c5782e1a7cdc0d8a01..c86de3d087024061dd87c3ab9f824cdf82ef2b0d 100644 (file)
@@ -36,6 +36,11 @@ class SubprocessStreamProtocol(streams.FlowControlMixin,
             info.append(f'stderr={self.stderr!r}')
         return '<{}>'.format(' '.join(info))
 
+    def _untrack_reader(self):
+        # StreamWriter.close() expects the protocol
+        # to have this method defined.
+        pass
+
     def connection_made(self, transport):
         self._transport = transport
 
index 66d18738b316269ff282a3781afbfb17e5f468bb..67ac9d91a0b1c52bf745de63d147d50c95d18cb1 100644 (file)
@@ -46,6 +46,8 @@ class StreamTests(test_utils.TestCase):
         self.assertIs(stream._loop, m_events.get_event_loop.return_value)
 
     def _basetest_open_connection(self, open_connection_fut):
+        messages = []
+        self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
         reader, writer = self.loop.run_until_complete(open_connection_fut)
         writer.write(b'GET / HTTP/1.0\r\n\r\n')
         f = reader.readline()
@@ -55,6 +57,7 @@ class StreamTests(test_utils.TestCase):
         data = self.loop.run_until_complete(f)
         self.assertTrue(data.endswith(b'\r\n\r\nTest message'))
         writer.close()
+        self.assertEqual(messages, [])
 
     def test_open_connection(self):
         with test_utils.run_test_server() as httpd:
@@ -70,6 +73,8 @@ class StreamTests(test_utils.TestCase):
             self._basetest_open_connection(conn_fut)
 
     def _basetest_open_connection_no_loop_ssl(self, open_connection_fut):
+        messages = []
+        self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
         try:
             reader, writer = self.loop.run_until_complete(open_connection_fut)
         finally:
@@ -80,6 +85,7 @@ class StreamTests(test_utils.TestCase):
         self.assertTrue(data.endswith(b'\r\n\r\nTest message'))
 
         writer.close()
+        self.assertEqual(messages, [])
 
     @unittest.skipIf(ssl is None, 'No ssl module')
     def test_open_connection_no_loop_ssl(self):
@@ -104,6 +110,8 @@ class StreamTests(test_utils.TestCase):
             self._basetest_open_connection_no_loop_ssl(conn_fut)
 
     def _basetest_open_connection_error(self, open_connection_fut):
+        messages = []
+        self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
         reader, writer = self.loop.run_until_complete(open_connection_fut)
         writer._protocol.connection_lost(ZeroDivisionError())
         f = reader.read()
@@ -111,6 +119,7 @@ class StreamTests(test_utils.TestCase):
             self.loop.run_until_complete(f)
         writer.close()
         test_utils.run_briefly(self.loop)
+        self.assertEqual(messages, [])
 
     def test_open_connection_error(self):
         with test_utils.run_test_server() as httpd:
@@ -621,6 +630,9 @@ class StreamTests(test_utils.TestCase):
             writer.close()
             return msgback
 
+        messages = []
+        self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
+
         # test the server variant with a coroutine as client handler
         server = MyServer(self.loop)
         addr = server.start()
@@ -637,6 +649,8 @@ class StreamTests(test_utils.TestCase):
         server.stop()
         self.assertEqual(msg, b"hello world!\n")
 
+        self.assertEqual(messages, [])
+
     @support.skip_unless_bind_unix_socket
     def test_start_unix_server(self):
 
@@ -685,6 +699,9 @@ class StreamTests(test_utils.TestCase):
             writer.close()
             return msgback
 
+        messages = []
+        self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
+
         # test the server variant with a coroutine as client handler
         with test_utils.unix_socket_path() as path:
             server = MyServer(self.loop, path)
@@ -703,6 +720,8 @@ class StreamTests(test_utils.TestCase):
             server.stop()
             self.assertEqual(msg, b"hello world!\n")
 
+        self.assertEqual(messages, [])
+
     @unittest.skipIf(sys.platform == 'win32', "Don't have pipes")
     def test_read_all_from_pipe_reader(self):
         # See asyncio issue 168.  This test is derived from the example
@@ -893,6 +912,58 @@ os.close(fd)
             wr.close()
             self.loop.run_until_complete(wr.wait_closed())
 
+    def test_del_stream_before_sock_closing(self):
+        messages = []
+        self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
+
+        with test_utils.run_test_server() as httpd:
+            rd, wr = self.loop.run_until_complete(
+                asyncio.open_connection(*httpd.address, loop=self.loop))
+            sock = wr.get_extra_info('socket')
+            self.assertNotEqual(sock.fileno(), -1)
+
+            wr.write(b'GET / HTTP/1.0\r\n\r\n')
+            f = rd.readline()
+            data = self.loop.run_until_complete(f)
+            self.assertEqual(data, b'HTTP/1.0 200 OK\r\n')
+
+            # drop refs to reader/writer
+            del rd
+            del wr
+            gc.collect()
+            # make a chance to close the socket
+            test_utils.run_briefly(self.loop)
+
+            self.assertEqual(1, len(messages))
+            self.assertEqual(sock.fileno(), -1)
+
+        self.assertEqual(1, len(messages))
+        self.assertEqual('An open stream object is being garbage '
+                         'collected; call "stream.close()" explicitly.',
+                         messages[0]['message'])
+
+    def test_del_stream_before_connection_made(self):
+        messages = []
+        self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
+
+        with test_utils.run_test_server() as httpd:
+            rd = asyncio.StreamReader(loop=self.loop)
+            pr = asyncio.StreamReaderProtocol(rd, loop=self.loop)
+            del rd
+            gc.collect()
+            tr, _ = self.loop.run_until_complete(
+                self.loop.create_connection(
+                    lambda: pr, *httpd.address))
+
+            sock = tr.get_extra_info('socket')
+            self.assertEqual(sock.fileno(), -1)
+
+        self.assertEqual(1, len(messages))
+        self.assertEqual('An open stream was garbage collected prior to '
+                         'establishing network connection; '
+                         'call "stream.close()" explicitly.',
+                         messages[0]['message'])
+
 
 if __name__ == '__main__':
     unittest.main()
diff --git a/Misc/NEWS.d/next/Library/2018-09-12-10-33-44.bpo-34638.xaeZX5.rst b/Misc/NEWS.d/next/Library/2018-09-12-10-33-44.bpo-34638.xaeZX5.rst
new file mode 100644 (file)
index 0000000..13b3952
--- /dev/null
@@ -0,0 +1,3 @@
+Store a weak reference to stream reader to break strong references loop
+between reader and protocol.  It allows to detect and close the socket if
+the stream is deleted (garbage collected) without ``close()`` call.