]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
Refactor async connect logic from SimpleAsyncHTTPClient to IOStream.
authorBen Darnell <ben@bendarnell.com>
Wed, 13 Oct 2010 00:59:29 +0000 (17:59 -0700)
committerBen Darnell <ben@bendarnell.com>
Wed, 13 Oct 2010 00:59:29 +0000 (17:59 -0700)
Closes #146.

tornado/iostream.py
tornado/simple_httpclient.py

index ef1288b0a4d82b7f4abb3c320d26d5662b8c7bc3..1cd9722a9fc79233cfca47fb119a554d30dd1852 100644 (file)
@@ -37,16 +37,18 @@ class IOStream(object):
     a given delimiter, and read_bytes() reads until a specified number
     of bytes have been read from the socket.
 
+    The socket parameter may either be connected or unconnected.  For
+    server operations the socket is the result of calling socket.accept().
+    For client operations the socket is created with socket.socket(),
+    and may either be connected before passing it to the IOStream or
+    connected with IOStream.connect.
+
     A very simple (and broken) HTTP client using this class:
 
-        import ioloop
-        import iostream
+        from tornado import ioloop
+        from tornado import iostream
         import socket
 
-        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
-        s.connect(("friendfeed.com", 80))
-        stream = IOStream(s)
-
         def on_headers(data):
             headers = {}
             for line in data.split("\r\n"):
@@ -60,7 +62,10 @@ class IOStream(object):
             stream.close()
             ioloop.IOLoop.instance().stop()
 
-        stream.write("GET / HTTP/1.0\r\n\r\n")
+        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
+        stream = iostream.IOStream(s)
+        stream.connect(("friendfeed.com", 80))
+        stream.write("GET / HTTP/1.0\r\nHost: friendfeed.com\r\n\r\n")
         stream.read_until("\r\n\r\n", on_headers)
         ioloop.IOLoop.instance().start()
 
@@ -79,11 +84,37 @@ class IOStream(object):
         self._read_callback = None
         self._write_callback = None
         self._close_callback = None
+        self._connect_callback = None
+        self._connecting = False
         self._state = self.io_loop.ERROR
         with stack_context.NullContext():
             self.io_loop.add_handler(
                 self.socket.fileno(), self._handle_events, self._state)
 
+    def connect(self, address, callback=None):
+        """Connects the socket to a remote address without blocking.
+
+        May only be called if the socket passed to the constructor was
+        not previously connected.  The address parameter is in the
+        same format as for socket.connect, i.e. a (host, port) tuple.
+        If callback is specified, it will be called when the
+        connection is completed.
+
+        Note that it is safe to call IOStream.write while the
+        connection is pending, in which case the data will be written
+        as soon as the connection is ready (see the example in the
+        docstring for this class).
+        """
+        self._connecting = True
+        try:
+            self.socket.connect(address)
+        except socket.error, e:
+            # In non-blocking mode connect() always raises EINPROGRESS
+            if e.errno != errno.EINPROGRESS:
+                raise
+        self._connect_callback = stack_context.wrap(callback)
+        self._add_io_state(self.io_loop.WRITE)
+
     def read_until(self, delimiter, callback):
         """Call callback when we read the given delimiter."""
         assert not self._read_callback, "Already reading"
@@ -160,6 +191,8 @@ class IOStream(object):
         if not self.socket:
             return
         if events & self.io_loop.WRITE:
+            if self._connecting:
+                self._handle_connect()
             self._handle_write()
         if not self.socket:
             return
@@ -278,6 +311,13 @@ class IOStream(object):
                 return True
         return False
 
+    def _handle_connect(self):
+        if self._connect_callback is not None:
+            callback = self._connect_callback
+            self._connect_callback = None
+            self._run_callback(callback)
+        self._connecting = False
+
     def _handle_write(self):
         while self._write_buffer:
             try:
@@ -316,11 +356,17 @@ class IOStream(object):
 
 
 class SSLIOStream(IOStream):
-    """Sets up an SSL connection in a non-blocking manner"""
+    """A utility class to write to and read from a non-blocking socket.
+
+    If the socket passed to the constructor is already connected,
+    it should be wrapped with
+        ssl.wrap_socket(sock, do_handshake_on_connect=False, **kwargs)
+    before constructing the SSLIOStream.  Unconnected sockets will be
+    wrapped when IOStream.connect is finished.
+    """
     def __init__(self, *args, **kwargs):
         super(SSLIOStream, self).__init__(*args, **kwargs)
         self._ssl_accepting = True
-        self._do_ssl_handshake()
 
     def _do_ssl_handshake(self):
         # Based on code from test_ssl.py in the python stdlib
@@ -355,6 +401,13 @@ class SSLIOStream(IOStream):
             return
         super(SSLIOStream, self)._handle_write()
 
+    def _handle_connect(self):
+        # TODO(bdarnell): cert verification, etc
+        self.socket = ssl.wrap_socket(self.socket,
+                                      do_handshake_on_connect=False)
+        super(SSLIOStream, self)._handle_connect()
+
+
     def _read_from_socket(self):
         try:
             # SSLSocket objects have both a read() and recv() method,
index b155a32a7057d8b06722caaa1cb77103eca3f46c..34a06e5de6064315d834c9ee522083676db57011 100644 (file)
@@ -68,34 +68,24 @@ class _HTTPConnection(object):
         self.chunks = None
         with stack_context.StackContext(self.cleanup):
             parsed = urlparse.urlsplit(self.request.url)
-            sock = socket.socket()
-            sock.setblocking(False)
             if ":" in parsed.netloc:
                 host, _, port = parsed.netloc.partition(":")
                 port = int(port)
             else:
                 host = parsed.netloc
                 port = 443 if parsed.scheme == "https" else 80
-            try:
-                sock.connect((host, port))
-            except socket.error, e:
-                # In non-blocking mode connect() always raises EINPROGRESS
-                if e.errno != errno.EINPROGRESS:
-                    raise
-            # Wait for the non-blocking connect to complete
-            self.io_loop.add_handler(sock.fileno(),
-                                     functools.partial(self._on_connect,
-                                                       sock, parsed),
-                                     IOLoop.WRITE)
-
-    def _on_connect(self, sock, parsed, fd, events):
-        self.io_loop.remove_handler(fd)
-        if parsed.scheme == "https":
-            # TODO: cert verification, etc
-            sock = ssl.wrap_socket(sock, do_handshake_on_connect=False)
-            self.stream = SSLIOStream(sock, io_loop=self.io_loop)
-        else:
-            self.stream = IOStream(sock, io_loop=self.io_loop)
+
+            if parsed.scheme == "https":
+                # TODO: cert verification, etc
+                self.stream = SSLIOStream(socket.socket(),
+                                          io_loop=self.io_loop)
+            else:
+                self.stream = IOStream(socket.socket(),
+                                       io_loop=self.io_loop)
+            self.stream.connect((host, port),
+                                functools.partial(self._on_connect, parsed))
+
+    def _on_connect(self, parsed):
         if "Host" not in self.request.headers:
             self.request.headers["Host"] = parsed.netloc
         has_body = self.request.method in ("POST", "PUT")