]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
Misc asyncio improvements from upstream
authorGuido van Rossum <guido@python.org>
Fri, 30 Sep 2016 15:17:15 +0000 (08:17 -0700)
committerGuido van Rossum <guido@python.org>
Fri, 30 Sep 2016 15:17:15 +0000 (08:17 -0700)
Lib/asyncio/base_events.py
Lib/asyncio/base_subprocess.py
Lib/asyncio/coroutines.py
Lib/asyncio/queues.py
Lib/asyncio/tasks.py
Lib/test/test_asyncio/test_base_events.py
Lib/test/test_asyncio/test_selector_events.py
Lib/test/test_asyncio/test_tasks.py

index 03935ea94ba53018ac7e624384d6fa1805c8ab85..af66c0a5b6fe9bb57626456df438f1357d9057fd 100644 (file)
@@ -115,24 +115,16 @@ def _ipaddr_info(host, port, family, type, proto):
 
     if port is None:
         port = 0
-    elif isinstance(port, bytes):
-        if port == b'':
-            port = 0
-        else:
-            try:
-                port = int(port)
-            except ValueError:
-                # Might be a service name like b"http".
-                port = socket.getservbyname(port.decode('ascii'))
-    elif isinstance(port, str):
-        if port == '':
-            port = 0
-        else:
-            try:
-                port = int(port)
-            except ValueError:
-                # Might be a service name like "http".
-                port = socket.getservbyname(port)
+    elif isinstance(port, bytes) and port == b'':
+        port = 0
+    elif isinstance(port, str) and port == '':
+        port = 0
+    else:
+        # If port's a service name like "http", don't skip getaddrinfo.
+        try:
+            port = int(port)
+        except (TypeError, ValueError):
+            return None
 
     if family == socket.AF_UNSPEC:
         afs = [socket.AF_INET, socket.AF_INET6]
index bcc481d20ea54a385d0ff56dbeb9fa7a83a78334..23742a169a473056fd0e26163a36df993edc8b80 100644 (file)
@@ -3,7 +3,6 @@ import subprocess
 import warnings
 
 from . import compat
-from . import futures
 from . import protocols
 from . import transports
 from .coroutines import coroutine
index e013d64edfcb59e1572c6b73dde6baa7078e3adb..5cecc762df97995d8de9a9467480eb2a903c030c 100644 (file)
@@ -120,8 +120,8 @@ class CoroWrapper:
         def send(self, value):
             return self.gen.send(value)
 
-    def throw(self, exc):
-        return self.gen.throw(exc)
+    def throw(self, type, value=None, traceback=None):
+        return self.gen.throw(type, value, traceback)
 
     def close(self):
         return self.gen.close()
index c453f02d8cf89973b39d6a46f332eeead621e4dc..2d38972c0de29d70f88c651ae3cf7e29f2f3249e 100644 (file)
@@ -7,7 +7,6 @@ import heapq
 
 from . import compat
 from . import events
-from . import futures
 from . import locks
 from .coroutines import coroutine
 
index 4c66546428b9dd92ffa3fd810ef7bb080bd2dfea..f735b44dc015a342bde3df51b86a2beafd20ff86 100644 (file)
@@ -594,6 +594,10 @@ def gather(*coros_or_futures, loop=None, return_exceptions=False):
     """Return a future aggregating results from the given coroutines
     or futures.
 
+    Coroutines will be wrapped in a future and scheduled in the event
+    loop. They will not necessarily be scheduled in the same order as
+    passed in.
+
     All futures must share the same event loop.  If all the tasks are
     done successfully, the returned future's result is the list of
     results (in the order of the original sequence, not necessarily
index 43ebdc8b2cba222a7b45bdd29c882fe9ac41e2cf..e86b74e61a3b508c7fb3ac5478968dd9beccfc8c 100644 (file)
@@ -142,26 +142,6 @@ class BaseEventTests(test_utils.TestCase):
             (INET, STREAM, TCP, '', ('1.2.3.4', 1)),
             base_events._ipaddr_info('1.2.3.4', b'1', INET, STREAM, TCP))
 
-    def test_getaddrinfo_servname(self):
-        INET = socket.AF_INET
-        STREAM = socket.SOCK_STREAM
-        TCP = socket.IPPROTO_TCP
-
-        self.assertEqual(
-            (INET, STREAM, TCP, '', ('1.2.3.4', 80)),
-            base_events._ipaddr_info('1.2.3.4', 'http', INET, STREAM, TCP))
-
-        self.assertEqual(
-            (INET, STREAM, TCP, '', ('1.2.3.4', 80)),
-            base_events._ipaddr_info('1.2.3.4', b'http', INET, STREAM, TCP))
-
-        # Raises "service/proto not found".
-        with self.assertRaises(OSError):
-            base_events._ipaddr_info('1.2.3.4', 'nonsense', INET, STREAM, TCP)
-
-        with self.assertRaises(OSError):
-            base_events._ipaddr_info('1.2.3.4', 'nonsense', INET, STREAM, TCP)
-
     @patch_socket
     def test_ipaddr_info_no_inet_pton(self, m_socket):
         del m_socket.inet_pton
@@ -1209,6 +1189,37 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase):
     def test_create_connection_no_inet_pton(self, m_socket):
         self._test_create_connection_ip_addr(m_socket, False)
 
+    @patch_socket
+    def test_create_connection_service_name(self, m_socket):
+        m_socket.getaddrinfo = socket.getaddrinfo
+        sock = m_socket.socket.return_value
+
+        self.loop.add_reader = mock.Mock()
+        self.loop.add_reader._is_coroutine = False
+        self.loop.add_writer = mock.Mock()
+        self.loop.add_writer._is_coroutine = False
+
+        for service, port in ('http', 80), (b'http', 80):
+            coro = self.loop.create_connection(asyncio.Protocol,
+                                               '127.0.0.1', service)
+
+            t, p = self.loop.run_until_complete(coro)
+            try:
+                sock.connect.assert_called_with(('127.0.0.1', port))
+                _, kwargs = m_socket.socket.call_args
+                self.assertEqual(kwargs['family'], m_socket.AF_INET)
+                self.assertEqual(kwargs['type'], m_socket.SOCK_STREAM)
+            finally:
+                t.close()
+                test_utils.run_briefly(self.loop)  # allow transport to close
+
+        for service in 'nonsense', b'nonsense':
+            coro = self.loop.create_connection(asyncio.Protocol,
+                                               '127.0.0.1', service)
+
+            with self.assertRaises(OSError):
+                self.loop.run_until_complete(coro)
+
     def test_create_connection_no_local_addr(self):
         @asyncio.coroutine
         def getaddrinfo(host, *args, **kw):
index 8b621bf22016b4533905996e05c31d07ff46c206..0c26a87dcdfc532579a31dc655375e738960579f 100644 (file)
@@ -2,6 +2,8 @@
 
 import errno
 import socket
+import threading
+import time
 import unittest
 from unittest import mock
 try:
@@ -1784,5 +1786,89 @@ class SelectorDatagramTransportTests(test_utils.TestCase):
                 'Fatal error on transport\nprotocol:.*\ntransport:.*'),
             exc_info=(ConnectionRefusedError, MOCK_ANY, MOCK_ANY))
 
+
+class SelectorLoopFunctionalTests(unittest.TestCase):
+
+    def setUp(self):
+        self.loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(None)
+
+    def tearDown(self):
+        self.loop.close()
+
+    @asyncio.coroutine
+    def recv_all(self, sock, nbytes):
+        buf = b''
+        while len(buf) < nbytes:
+            buf += yield from self.loop.sock_recv(sock, nbytes - len(buf))
+        return buf
+
+    def test_sock_connect_sock_write_race(self):
+        TIMEOUT = 3.0
+        PAYLOAD = b'DATA' * 1024 * 1024
+
+        class Server(threading.Thread):
+            def __init__(self, *args, srv_sock, **kwargs):
+                super().__init__(*args, **kwargs)
+                self.srv_sock = srv_sock
+
+            def run(self):
+                with self.srv_sock:
+                    srv_sock.listen(100)
+
+                    sock, addr = self.srv_sock.accept()
+                    sock.settimeout(TIMEOUT)
+
+                    with sock:
+                        sock.sendall(b'helo')
+
+                        buf = bytearray()
+                        while len(buf) < len(PAYLOAD):
+                            pack = sock.recv(1024 * 65)
+                            if not pack:
+                                break
+                            buf.extend(pack)
+
+        @asyncio.coroutine
+        def client(addr):
+            sock = socket.socket()
+            with sock:
+                sock.setblocking(False)
+
+                started = time.monotonic()
+                while True:
+                    if time.monotonic() - started > TIMEOUT:
+                        self.fail('unable to connect to the socket')
+                        return
+                    try:
+                        yield from self.loop.sock_connect(sock, addr)
+                    except OSError:
+                        yield from asyncio.sleep(0.05, loop=self.loop)
+                    else:
+                        break
+
+                # Give 'Server' thread a chance to accept and send b'helo'
+                time.sleep(0.1)
+
+                data = yield from self.recv_all(sock, 4)
+                self.assertEqual(data, b'helo')
+                yield from self.loop.sock_sendall(sock, PAYLOAD)
+
+        srv_sock = socket.socket()
+        srv_sock.settimeout(TIMEOUT)
+        srv_sock.bind(('127.0.0.1', 0))
+        srv_addr = srv_sock.getsockname()
+
+        srv = Server(srv_sock=srv_sock, daemon=True)
+        srv.start()
+
+        try:
+            self.loop.run_until_complete(
+                asyncio.wait_for(client(srv_addr), loop=self.loop,
+                                 timeout=TIMEOUT))
+        finally:
+            srv.join()
+
+
 if __name__ == '__main__':
     unittest.main()
index e7fb774fcae14f9ff9319a683ac04a3545412055..2863c423b5190ee8bcb905db4d634c627bd3d040 100644 (file)
@@ -1723,6 +1723,37 @@ class TaskTests(test_utils.TestCase):
         wd['cw'] = cw  # Would fail without __weakref__ slot.
         cw.gen = None  # Suppress warning from __del__.
 
+    def test_corowrapper_throw(self):
+        # Issue 429: CoroWrapper.throw must be compatible with gen.throw
+        def foo():
+            value = None
+            while True:
+                try:
+                    value = yield value
+                except Exception as e:
+                    value = e
+
+        exception = Exception("foo")
+        cw = asyncio.coroutines.CoroWrapper(foo())
+        cw.send(None)
+        self.assertIs(exception, cw.throw(exception))
+
+        cw = asyncio.coroutines.CoroWrapper(foo())
+        cw.send(None)
+        self.assertIs(exception, cw.throw(Exception, exception))
+
+        cw = asyncio.coroutines.CoroWrapper(foo())
+        cw.send(None)
+        exception = cw.throw(Exception, "foo")
+        self.assertIsInstance(exception, Exception)
+        self.assertEqual(exception.args, ("foo", ))
+
+        cw = asyncio.coroutines.CoroWrapper(foo())
+        cw.send(None)
+        exception = cw.throw(Exception, "foo", None)
+        self.assertIsInstance(exception, Exception)
+        self.assertEqual(exception.args, ("foo", ))
+
     @unittest.skipUnless(PY34,
                          'need python 3.4 or later')
     def test_log_destroyed_pending_task(self):