]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
Add support for binding Unix sockets in Linux's abstract namespace. (#3405)
authorPeter Stokes <144905750+Dadeos-Menlo@users.noreply.github.com>
Tue, 9 Jul 2024 14:26:46 +0000 (15:26 +0100)
committerGitHub <noreply@github.com>
Tue, 9 Jul 2024 14:26:46 +0000 (10:26 -0400)
tornado/netutil.py
tornado/test/httpserver_test.py

index e74cf0f1ec1e32e8e2a0cd97f597d314bc12e786..b6772086af46203303b2aa066463ffe3a57213c6 100644 (file)
@@ -209,17 +209,23 @@ if hasattr(socket, "AF_UNIX"):
                 # Hurd doesn't support SO_REUSEADDR
                 raise
         sock.setblocking(False)
-        try:
-            st = os.stat(file)
-        except FileNotFoundError:
-            pass
-        else:
-            if stat.S_ISSOCK(st.st_mode):
-                os.remove(file)
+        # File names comprising of an initial null-byte denote an abstract
+        # namespace, on Linux, and therefore are not subject to file system
+        # orientated processing.
+        if not file.startswith("\0"):
+            try:
+                st = os.stat(file)
+            except FileNotFoundError:
+                pass
             else:
-                raise ValueError("File %s exists and is not a socket", file)
-        sock.bind(file)
-        os.chmod(file, mode)
+                if stat.S_ISSOCK(st.st_mode):
+                    os.remove(file)
+                else:
+                    raise ValueError("File %s exists and is not a socket", file)
+            sock.bind(file)
+            os.chmod(file, mode)
+        else:
+            sock.bind(file)
         sock.listen(backlog)
         return sock
 
index 5657f629c3ce899d7ab90859c9ff944eac6c45c9..2112be668c5a36ee3a2d035cb85dfdc126c94dc0 100644 (file)
@@ -44,6 +44,7 @@ import tempfile
 import textwrap
 import unittest
 import urllib.parse
+import uuid
 from io import BytesIO
 
 import typing
@@ -813,10 +814,6 @@ class ManualProtocolTest(HandlerBaseTestCase):
         self.assertEqual(self.fetch_json("/")["protocol"], "https")
 
 
-@unittest.skipIf(
-    not hasattr(socket, "AF_UNIX") or sys.platform == "cygwin",
-    "unix sockets not supported on this platform",
-)
 class UnixSocketTest(AsyncTestCase):
     """HTTPServers can listen on Unix sockets too.
 
@@ -828,44 +825,72 @@ class UnixSocketTest(AsyncTestCase):
     an HTTP client, so we have to test this by hand.
     """
 
+    address = ""
+
     def setUp(self):
+        if type(self) is UnixSocketTest:
+            raise unittest.SkipTest("abstract base class")
         super().setUp()
-        self.tmpdir = tempfile.mkdtemp()
-        self.sockfile = os.path.join(self.tmpdir, "test.sock")
-        sock = netutil.bind_unix_socket(self.sockfile)
         app = Application([("/hello", HelloWorldRequestHandler)])
         self.server = HTTPServer(app)
-        self.server.add_socket(sock)
-        self.stream = IOStream(socket.socket(socket.AF_UNIX))
-        self.io_loop.run_sync(lambda: self.stream.connect(self.sockfile))
+        self.server.add_socket(netutil.bind_unix_socket(self.address))
 
     def tearDown(self):
-        self.stream.close()
         self.io_loop.run_sync(self.server.close_all_connections)
         self.server.stop()
-        shutil.rmtree(self.tmpdir)
         super().tearDown()
 
     @gen_test
     def test_unix_socket(self):
-        self.stream.write(b"GET /hello HTTP/1.0\r\n\r\n")
-        response = yield self.stream.read_until(b"\r\n")
-        self.assertEqual(response, b"HTTP/1.1 200 OK\r\n")
-        header_data = yield self.stream.read_until(b"\r\n\r\n")
-        headers = HTTPHeaders.parse(header_data.decode("latin1"))
-        body = yield self.stream.read_bytes(int(headers["Content-Length"]))
-        self.assertEqual(body, b"Hello world")
+        with closing(IOStream(socket.socket(socket.AF_UNIX))) as stream:
+            stream.connect(self.address)
+            stream.write(b"GET /hello HTTP/1.0\r\n\r\n")
+            response = yield stream.read_until(b"\r\n")
+            self.assertEqual(response, b"HTTP/1.1 200 OK\r\n")
+            header_data = yield stream.read_until(b"\r\n\r\n")
+            headers = HTTPHeaders.parse(header_data.decode("latin1"))
+            body = yield stream.read_bytes(int(headers["Content-Length"]))
+            self.assertEqual(body, b"Hello world")
 
     @gen_test
     def test_unix_socket_bad_request(self):
         # Unix sockets don't have remote addresses so they just return an
         # empty string.
         with ExpectLog(gen_log, "Malformed HTTP message from", level=logging.INFO):
-            self.stream.write(b"garbage\r\n\r\n")
-            response = yield self.stream.read_until_close()
+            with closing(IOStream(socket.socket(socket.AF_UNIX))) as stream:
+                stream.connect(self.address)
+                stream.write(b"garbage\r\n\r\n")
+                response = yield stream.read_until_close()
         self.assertEqual(response, b"HTTP/1.1 400 Bad Request\r\n\r\n")
 
 
+@unittest.skipIf(
+    not hasattr(socket, "AF_UNIX") or sys.platform == "cygwin",
+    "unix sockets not supported on this platform",
+)
+class UnixSocketTestAbstract(UnixSocketTest):
+
+    def setUp(self):
+        self.tmpdir = tempfile.mkdtemp()
+        self.address = os.path.join(self.tmpdir, "test.sock")
+        super().setUp()
+
+    def tearDown(self):
+        super().tearDown()
+        shutil.rmtree(self.tmpdir)
+
+
+@unittest.skipIf(
+    not (hasattr(socket, "AF_UNIX") and sys.platform.startswith("linux")),
+    "abstract namespace unix sockets not supported on this platform",
+)
+class UnixSocketTestFile(UnixSocketTest):
+
+    def setUp(self):
+        self.address = "\0" + uuid.uuid4().hex
+        super().setUp()
+
+
 class KeepAliveTest(AsyncHTTPTestCase):
     """Tests various scenarios for HTTP 1.1 keep-alive support.