From dd4b0fbf80ad60f6d61e3c191c94fba4a2148abc Mon Sep 17 00:00:00 2001 From: Ben Darnell Date: Thu, 7 Jul 2011 22:14:36 -0700 Subject: [PATCH] Add basic unix socket support. tornado.netutil.bind_unix_socket can create non-blocking listening unix sockets, and HTTPServer can use them. (no client-side support for this yet) This is useful e.g. with nginx proxying incoming TCP traffic to a backend over a unix socket (which may be easier to manage than a set of TCP ports) --- tornado/httpserver.py | 7 +++++ tornado/iostream.py | 1 + tornado/netutil.py | 33 ++++++++++++++++++++ tornado/test/httpserver_test.py | 54 ++++++++++++++++++++++++++++++--- 4 files changed, 91 insertions(+), 4 deletions(-) diff --git a/tornado/httpserver.py b/tornado/httpserver.py index eddb75122..6d58d0d2c 100644 --- a/tornado/httpserver.py +++ b/tornado/httpserver.py @@ -170,6 +170,10 @@ class HTTPServer(object): self.io_loop.add_handler(sock.fileno(), self._handle_events, ioloop.IOLoop.READ) + def add_socket(self, socket): + """Singular version of `add_sockets`. Takes a single socket object.""" + self.add_sockets([socket]) + def bind(self, port, address=None, family=socket.AF_UNSPEC, backlog=128): """Binds this server to the given port on the given address. @@ -264,6 +268,9 @@ class HTTPServer(object): stream = iostream.SSLIOStream(connection, io_loop=self.io_loop) else: stream = iostream.IOStream(connection, io_loop=self.io_loop) + if connection.family not in (socket.AF_INET, socket.AF_INET6): + # Unix (or other) socket; fake the remote address + address = ('0.0.0.0', 0) HTTPConnection(stream, address, self.request_callback, self.no_keep_alive, self.xheaders) except Exception: diff --git a/tornado/iostream.py b/tornado/iostream.py index abf34a2a5..f2af5375f 100644 --- a/tornado/iostream.py +++ b/tornado/iostream.py @@ -141,6 +141,7 @@ class IOStream(object): def read_bytes(self, num_bytes, callback): """Call callback when we read the given number of bytes.""" assert not self._read_callback, "Already reading" + assert isinstance(num_bytes, int) self._read_bytes = num_bytes self._read_callback = stack_context.wrap(callback) while True: diff --git a/tornado/netutil.py b/tornado/netutil.py index 6d6c5ce11..05921dfae 100644 --- a/tornado/netutil.py +++ b/tornado/netutil.py @@ -16,7 +16,10 @@ """Miscellaneous network utility code.""" +import errno +import os import socket +import stat from tornado.platform.auto import set_close_exec @@ -69,3 +72,33 @@ def bind_sockets(port, address=None, family=socket.AF_UNSPEC, backlog=128): sock.listen(backlog) sockets.append(sock) return sockets + +if hasattr(socket, 'AF_UNIX'): + def bind_unix_socket(file, mode=0600, backlog=128): + """Creates a listening unix socket. + + If a socket with the given name already exists, it will be deleted. + If any other file with that name exists, an exception will be + raised. + + Returns a socket object (not a list of socket objects like + `bind_sockets`) + """ + sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + set_close_exec(sock.fileno()) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.setblocking(0) + try: + st = os.stat(file) + except OSError, err: + if err.errno != errno.ENOENT: + raise + else: + if stat.S_ISSOCK(st.st_mode): + os.remove(file) + else: + raise ValueError("File %s exists and is not a socket", file) + sock.bind(file) + os.chmod(file, mode) + sock.listen(backlog) + return sock diff --git a/tornado/test/httpserver_test.py b/tornado/test/httpserver_test.py index fa48f5c21..57d8f4686 100644 --- a/tornado/test/httpserver_test.py +++ b/tornado/test/httpserver_test.py @@ -1,12 +1,18 @@ #!/usr/bin/env python -from tornado import httpclient, simple_httpclient +from tornado import httpclient, simple_httpclient, netutil from tornado.escape import json_decode, utf8, _unicode, recursive_unicode +from tornado.httpserver import HTTPServer +from tornado.httputil import HTTPHeaders +from tornado.iostream import IOStream from tornado.simple_httpclient import SimpleAsyncHTTPClient -from tornado.testing import AsyncHTTPTestCase, LogTrapTestCase +from tornado.testing import AsyncHTTPTestCase, LogTrapTestCase, AsyncTestCase from tornado.util import b, bytes_type from tornado.web import Application, RequestHandler import os +import shutil +import socket +import tempfile try: import ssl @@ -14,8 +20,11 @@ except ImportError: ssl = None class HelloWorldRequestHandler(RequestHandler): + def initialize(self, protocol="http"): + self.expected_protocol = protocol + def get(self): - assert self.request.protocol == "https" + assert self.request.protocol == self.expected_protocol self.finish("Hello world") def post(self): @@ -31,7 +40,8 @@ class SSLTest(AsyncHTTPTestCase, LogTrapTestCase): force_instance=True) def get_app(self): - return Application([('/', HelloWorldRequestHandler)]) + return Application([('/', HelloWorldRequestHandler, + dict(protocol="https"))]) def get_httpserver_options(self): # Testing keys were generated with: @@ -189,3 +199,39 @@ class HTTPServerTest(AsyncHTTPTestCase, LogTrapTestCase): data = json_decode(response.body) self.assertEqual(data, {}) +class UnixSocketTest(AsyncTestCase, LogTrapTestCase): + """HTTPServers can listen on Unix sockets too. + + Why would you want to do this? Nginx can proxy to backends listening + on unix sockets, for one thing (and managing a namespace for unix + sockets can be easier than managing a bunch of TCP port numbers). + + Unfortunately, there's no way to specify a unix socket in a url for + an HTTP client, so we have to test this by hand. + """ + def setUp(self): + super(UnixSocketTest, self).setUp() + self.tmpdir = tempfile.mkdtemp() + + def tearDown(self): + shutil.rmtree(self.tmpdir) + super(UnixSocketTest, self).tearDown() + + def test_unix_socket(self): + sockfile = os.path.join(self.tmpdir, "test.sock") + sock = netutil.bind_unix_socket(sockfile) + app = Application([("/hello", HelloWorldRequestHandler)]) + server = HTTPServer(app, io_loop=self.io_loop) + server.add_socket(sock) + stream = IOStream(socket.socket(socket.AF_UNIX), io_loop=self.io_loop) + stream.connect(sockfile, self.stop) + self.wait() + stream.write(b("GET /hello HTTP/1.0\r\n\r\n")) + stream.read_until(b("\r\n"), self.stop) + response = self.wait() + self.assertEqual(response, b("HTTP/1.0 200 OK\r\n")) + stream.read_until(b("\r\n\r\n"), self.stop) + headers = HTTPHeaders.parse(self.wait().decode('latin1')) + stream.read_bytes(int(headers["Content-Length"]), self.stop) + body = self.wait() + self.assertEqual(body, b("Hello world")) -- 2.47.3