From: Peter Stokes <144905750+Dadeos-Menlo@users.noreply.github.com> Date: Tue, 9 Jul 2024 14:26:46 +0000 (+0100) Subject: Add support for binding Unix sockets in Linux's abstract namespace. (#3405) X-Git-Tag: v6.5.0b1~42 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=58a548272d17cac8895c269be8f2d5b5cf62e5e3;p=thirdparty%2Ftornado.git Add support for binding Unix sockets in Linux's abstract namespace. (#3405) --- diff --git a/tornado/netutil.py b/tornado/netutil.py index e74cf0f1..b6772086 100644 --- a/tornado/netutil.py +++ b/tornado/netutil.py @@ -209,17 +209,23 @@ if hasattr(socket, "AF_UNIX"): # Hurd doesn't support SO_REUSEADDR raise sock.setblocking(False) - try: - st = os.stat(file) - except FileNotFoundError: - pass - else: - if stat.S_ISSOCK(st.st_mode): - os.remove(file) + # File names comprising of an initial null-byte denote an abstract + # namespace, on Linux, and therefore are not subject to file system + # orientated processing. + if not file.startswith("\0"): + try: + st = os.stat(file) + except FileNotFoundError: + pass else: - raise ValueError("File %s exists and is not a socket", file) - sock.bind(file) - os.chmod(file, mode) + if stat.S_ISSOCK(st.st_mode): + os.remove(file) + else: + raise ValueError("File %s exists and is not a socket", file) + sock.bind(file) + os.chmod(file, mode) + else: + sock.bind(file) sock.listen(backlog) return sock diff --git a/tornado/test/httpserver_test.py b/tornado/test/httpserver_test.py index 5657f629..2112be66 100644 --- a/tornado/test/httpserver_test.py +++ b/tornado/test/httpserver_test.py @@ -44,6 +44,7 @@ import tempfile import textwrap import unittest import urllib.parse +import uuid from io import BytesIO import typing @@ -813,10 +814,6 @@ class ManualProtocolTest(HandlerBaseTestCase): self.assertEqual(self.fetch_json("/")["protocol"], "https") -@unittest.skipIf( - not hasattr(socket, "AF_UNIX") or sys.platform == "cygwin", - "unix sockets not supported on this platform", -) class UnixSocketTest(AsyncTestCase): """HTTPServers can listen on Unix sockets too. @@ -828,44 +825,72 @@ class UnixSocketTest(AsyncTestCase): an HTTP client, so we have to test this by hand. """ + address = "" + def setUp(self): + if type(self) is UnixSocketTest: + raise unittest.SkipTest("abstract base class") super().setUp() - self.tmpdir = tempfile.mkdtemp() - self.sockfile = os.path.join(self.tmpdir, "test.sock") - sock = netutil.bind_unix_socket(self.sockfile) app = Application([("/hello", HelloWorldRequestHandler)]) self.server = HTTPServer(app) - self.server.add_socket(sock) - self.stream = IOStream(socket.socket(socket.AF_UNIX)) - self.io_loop.run_sync(lambda: self.stream.connect(self.sockfile)) + self.server.add_socket(netutil.bind_unix_socket(self.address)) def tearDown(self): - self.stream.close() self.io_loop.run_sync(self.server.close_all_connections) self.server.stop() - shutil.rmtree(self.tmpdir) super().tearDown() @gen_test def test_unix_socket(self): - self.stream.write(b"GET /hello HTTP/1.0\r\n\r\n") - response = yield self.stream.read_until(b"\r\n") - self.assertEqual(response, b"HTTP/1.1 200 OK\r\n") - header_data = yield self.stream.read_until(b"\r\n\r\n") - headers = HTTPHeaders.parse(header_data.decode("latin1")) - body = yield self.stream.read_bytes(int(headers["Content-Length"])) - self.assertEqual(body, b"Hello world") + with closing(IOStream(socket.socket(socket.AF_UNIX))) as stream: + stream.connect(self.address) + stream.write(b"GET /hello HTTP/1.0\r\n\r\n") + response = yield stream.read_until(b"\r\n") + self.assertEqual(response, b"HTTP/1.1 200 OK\r\n") + header_data = yield stream.read_until(b"\r\n\r\n") + headers = HTTPHeaders.parse(header_data.decode("latin1")) + body = yield stream.read_bytes(int(headers["Content-Length"])) + self.assertEqual(body, b"Hello world") @gen_test def test_unix_socket_bad_request(self): # Unix sockets don't have remote addresses so they just return an # empty string. with ExpectLog(gen_log, "Malformed HTTP message from", level=logging.INFO): - self.stream.write(b"garbage\r\n\r\n") - response = yield self.stream.read_until_close() + with closing(IOStream(socket.socket(socket.AF_UNIX))) as stream: + stream.connect(self.address) + stream.write(b"garbage\r\n\r\n") + response = yield stream.read_until_close() self.assertEqual(response, b"HTTP/1.1 400 Bad Request\r\n\r\n") +@unittest.skipIf( + not hasattr(socket, "AF_UNIX") or sys.platform == "cygwin", + "unix sockets not supported on this platform", +) +class UnixSocketTestAbstract(UnixSocketTest): + + def setUp(self): + self.tmpdir = tempfile.mkdtemp() + self.address = os.path.join(self.tmpdir, "test.sock") + super().setUp() + + def tearDown(self): + super().tearDown() + shutil.rmtree(self.tmpdir) + + +@unittest.skipIf( + not (hasattr(socket, "AF_UNIX") and sys.platform.startswith("linux")), + "abstract namespace unix sockets not supported on this platform", +) +class UnixSocketTestFile(UnixSocketTest): + + def setUp(self): + self.address = "\0" + uuid.uuid4().hex + super().setUp() + + class KeepAliveTest(AsyncHTTPTestCase): """Tests various scenarios for HTTP 1.1 keep-alive support.