]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
Add basic unix socket support.
authorBen Darnell <ben@bendarnell.com>
Fri, 8 Jul 2011 05:14:36 +0000 (22:14 -0700)
committerBen Darnell <ben@bendarnell.com>
Fri, 8 Jul 2011 05:14:36 +0000 (22:14 -0700)
tornado.netutil.bind_unix_socket can create non-blocking listening unix
sockets, and HTTPServer can use them.  (no client-side support for this yet)
This is useful e.g. with nginx proxying incoming TCP traffic to
a backend over a unix socket (which may be easier to manage than a set of
TCP ports)

tornado/httpserver.py
tornado/iostream.py
tornado/netutil.py
tornado/test/httpserver_test.py

index eddb751228a949cec2d5096f7453c653394e1129..6d58d0d2c7c14f4914678050b41fdb26c3858b35 100644 (file)
@@ -170,6 +170,10 @@ class HTTPServer(object):
             self.io_loop.add_handler(sock.fileno(), self._handle_events,
                                      ioloop.IOLoop.READ)
 
+    def add_socket(self, socket):
+        """Singular version of `add_sockets`.  Takes a single socket object."""
+        self.add_sockets([socket])
+
     def bind(self, port, address=None, family=socket.AF_UNSPEC, backlog=128):
         """Binds this server to the given port on the given address.
 
@@ -264,6 +268,9 @@ class HTTPServer(object):
                     stream = iostream.SSLIOStream(connection, io_loop=self.io_loop)
                 else:
                     stream = iostream.IOStream(connection, io_loop=self.io_loop)
+                if connection.family not in (socket.AF_INET, socket.AF_INET6):
+                    # Unix (or other) socket; fake the remote address
+                    address = ('0.0.0.0', 0)
                 HTTPConnection(stream, address, self.request_callback,
                                self.no_keep_alive, self.xheaders)
             except Exception:
index abf34a2a58473c162629aa649b08326e7eeff793..f2af5375f7613afdcdb21154d414c53d6a5cfad2 100644 (file)
@@ -141,6 +141,7 @@ class IOStream(object):
     def read_bytes(self, num_bytes, callback):
         """Call callback when we read the given number of bytes."""
         assert not self._read_callback, "Already reading"
+        assert isinstance(num_bytes, int)
         self._read_bytes = num_bytes
         self._read_callback = stack_context.wrap(callback)
         while True:
index 6d6c5ce11007cf47041a570727a31333605a41ec..05921dfaeba40c246ac2fdeee9271f9a47efa767 100644 (file)
 
 """Miscellaneous network utility code."""
 
+import errno
+import os
 import socket
+import stat
 
 from tornado.platform.auto import set_close_exec
 
@@ -69,3 +72,33 @@ def bind_sockets(port, address=None, family=socket.AF_UNSPEC, backlog=128):
         sock.listen(backlog)
         sockets.append(sock)
     return sockets
+
+if hasattr(socket, 'AF_UNIX'):
+    def bind_unix_socket(file, mode=0600, backlog=128):
+        """Creates a listening unix socket.
+
+        If a socket with the given name already exists, it will be deleted.
+        If any other file with that name exists, an exception will be
+        raised.
+
+        Returns a socket object (not a list of socket objects like 
+        `bind_sockets`)
+        """
+        sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
+        set_close_exec(sock.fileno())
+        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+        sock.setblocking(0)
+        try:
+            st = os.stat(file)
+        except OSError, err:
+            if err.errno != errno.ENOENT:
+                raise
+        else:
+            if stat.S_ISSOCK(st.st_mode):
+                os.remove(file)
+            else:
+                raise ValueError("File %s exists and is not a socket", file)
+        sock.bind(file)
+        os.chmod(file, mode)
+        sock.listen(backlog)
+        return sock
index fa48f5c21292c3e23f1323305901a03d6b235962..57d8f4686a6183ec045ccfa8d70766fb14aa6f13 100644 (file)
@@ -1,12 +1,18 @@
 #!/usr/bin/env python
 
-from tornado import httpclient, simple_httpclient
+from tornado import httpclient, simple_httpclient, netutil
 from tornado.escape import json_decode, utf8, _unicode, recursive_unicode
+from tornado.httpserver import HTTPServer
+from tornado.httputil import HTTPHeaders
+from tornado.iostream import IOStream
 from tornado.simple_httpclient import SimpleAsyncHTTPClient
-from tornado.testing import AsyncHTTPTestCase, LogTrapTestCase
+from tornado.testing import AsyncHTTPTestCase, LogTrapTestCase, AsyncTestCase
 from tornado.util import b, bytes_type
 from tornado.web import Application, RequestHandler
 import os
+import shutil
+import socket
+import tempfile
 
 try:
     import ssl
@@ -14,8 +20,11 @@ except ImportError:
     ssl = None
 
 class HelloWorldRequestHandler(RequestHandler):
+    def initialize(self, protocol="http"):
+        self.expected_protocol = protocol
+
     def get(self):
-        assert self.request.protocol == "https"
+        assert self.request.protocol == self.expected_protocol
         self.finish("Hello world")
 
     def post(self):
@@ -31,7 +40,8 @@ class SSLTest(AsyncHTTPTestCase, LogTrapTestCase):
                                                  force_instance=True)
 
     def get_app(self):
-        return Application([('/', HelloWorldRequestHandler)])
+        return Application([('/', HelloWorldRequestHandler, 
+                             dict(protocol="https"))])
 
     def get_httpserver_options(self):
         # Testing keys were generated with:
@@ -189,3 +199,39 @@ class HTTPServerTest(AsyncHTTPTestCase, LogTrapTestCase):
         data = json_decode(response.body)
         self.assertEqual(data, {})
 
+class UnixSocketTest(AsyncTestCase, LogTrapTestCase):
+    """HTTPServers can listen on Unix sockets too.
+
+    Why would you want to do this?  Nginx can proxy to backends listening
+    on unix sockets, for one thing (and managing a namespace for unix
+    sockets can be easier than managing a bunch of TCP port numbers).
+
+    Unfortunately, there's no way to specify a unix socket in a url for
+    an HTTP client, so we have to test this by hand.
+    """
+    def setUp(self):
+        super(UnixSocketTest, self).setUp()
+        self.tmpdir = tempfile.mkdtemp()
+
+    def tearDown(self):
+        shutil.rmtree(self.tmpdir)
+        super(UnixSocketTest, self).tearDown()
+
+    def test_unix_socket(self):
+        sockfile = os.path.join(self.tmpdir, "test.sock")
+        sock = netutil.bind_unix_socket(sockfile)
+        app = Application([("/hello", HelloWorldRequestHandler)])
+        server = HTTPServer(app, io_loop=self.io_loop)
+        server.add_socket(sock)
+        stream = IOStream(socket.socket(socket.AF_UNIX), io_loop=self.io_loop)
+        stream.connect(sockfile, self.stop)
+        self.wait()
+        stream.write(b("GET /hello HTTP/1.0\r\n\r\n"))
+        stream.read_until(b("\r\n"), self.stop)
+        response = self.wait()
+        self.assertEqual(response, b("HTTP/1.0 200 OK\r\n"))
+        stream.read_until(b("\r\n\r\n"), self.stop)
+        headers = HTTPHeaders.parse(self.wait().decode('latin1'))
+        stream.read_bytes(int(headers["Content-Length"]), self.stop)
+        body = self.wait()
+        self.assertEqual(body, b("Hello world"))