]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
asyncio: Only allow Unix Stream sockets for loop.create_unix_server/connection
authorYury Selivanov <yury@magic.io>
Fri, 7 Oct 2016 16:39:57 +0000 (12:39 -0400)
committerYury Selivanov <yury@magic.io>
Fri, 7 Oct 2016 16:39:57 +0000 (12:39 -0400)
Lib/asyncio/unix_events.py
Lib/test/test_asyncio/test_unix_events.py

index e1f5c52be00cb08858d8610426320c856b60f521..42a8b85981f54f996b861a9709f126529e784a6f 100644 (file)
@@ -234,6 +234,11 @@ class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop):
         else:
             if sock is None:
                 raise ValueError('no path and sock were specified')
+            if (sock.family != socket.AF_UNIX or
+                    sock.type != socket.SOCK_STREAM):
+                raise ValueError(
+                    'A UNIX Domain Stream Socket was expected, got {!r}'
+                    .format(sock))
             sock.setblocking(False)
 
         transport, protocol = yield from self._create_connection_transport(
@@ -272,9 +277,11 @@ class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop):
                 raise ValueError(
                     'path was not specified, and no sock specified')
 
-            if sock.family != socket.AF_UNIX:
+            if (sock.family != socket.AF_UNIX or
+                    sock.type != socket.SOCK_STREAM):
                 raise ValueError(
-                    'A UNIX Domain Socket was expected, got {!r}'.format(sock))
+                    'A UNIX Domain Stream Socket was expected, got {!r}'
+                    .format(sock))
 
         server = base_events.Server(self, [sock])
         sock.listen(backlog)
index 088ef408e5217f0b95da587173818cb5d4f146c1..0d54e3a6d47fd103fc6d09021967aea0c45daea6 100644 (file)
@@ -273,7 +273,16 @@ class SelectorEventLoopUnixSocketTests(test_utils.TestCase):
             coro = self.loop.create_unix_server(lambda: None, path=None,
                                                 sock=sock)
             with self.assertRaisesRegex(ValueError,
-                                        'A UNIX Domain Socket was expected'):
+                                        'A UNIX Domain Stream.*was expected'):
+                self.loop.run_until_complete(coro)
+
+    def test_create_unix_connection_path_inetsock(self):
+        sock = socket.socket()
+        with sock:
+            coro = self.loop.create_unix_connection(lambda: None, path=None,
+                                                    sock=sock)
+            with self.assertRaisesRegex(ValueError,
+                                        'A UNIX Domain Stream.*was expected'):
                 self.loop.run_until_complete(coro)
 
     @mock.patch('asyncio.unix_events.socket')