From: Ben Darnell Date: Wed, 13 Oct 2010 00:59:29 +0000 (-0700) Subject: Refactor async connect logic from SimpleAsyncHTTPClient to IOStream. X-Git-Tag: v1.2.0~104 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=f10c7176f11ee26de1aa063ee76af672ca5854d0;p=thirdparty%2Ftornado.git Refactor async connect logic from SimpleAsyncHTTPClient to IOStream. Closes #146. --- diff --git a/tornado/iostream.py b/tornado/iostream.py index ef1288b0a..1cd9722a9 100644 --- a/tornado/iostream.py +++ b/tornado/iostream.py @@ -37,16 +37,18 @@ class IOStream(object): a given delimiter, and read_bytes() reads until a specified number of bytes have been read from the socket. + The socket parameter may either be connected or unconnected. For + server operations the socket is the result of calling socket.accept(). + For client operations the socket is created with socket.socket(), + and may either be connected before passing it to the IOStream or + connected with IOStream.connect. + A very simple (and broken) HTTP client using this class: - import ioloop - import iostream + from tornado import ioloop + from tornado import iostream import socket - s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0) - s.connect(("friendfeed.com", 80)) - stream = IOStream(s) - def on_headers(data): headers = {} for line in data.split("\r\n"): @@ -60,7 +62,10 @@ class IOStream(object): stream.close() ioloop.IOLoop.instance().stop() - stream.write("GET / HTTP/1.0\r\n\r\n") + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0) + stream = iostream.IOStream(s) + stream.connect(("friendfeed.com", 80)) + stream.write("GET / HTTP/1.0\r\nHost: friendfeed.com\r\n\r\n") stream.read_until("\r\n\r\n", on_headers) ioloop.IOLoop.instance().start() @@ -79,11 +84,37 @@ class IOStream(object): self._read_callback = None self._write_callback = None self._close_callback = None + self._connect_callback = None + self._connecting = False self._state = self.io_loop.ERROR with stack_context.NullContext(): self.io_loop.add_handler( self.socket.fileno(), self._handle_events, self._state) + def connect(self, address, callback=None): + """Connects the socket to a remote address without blocking. + + May only be called if the socket passed to the constructor was + not previously connected. The address parameter is in the + same format as for socket.connect, i.e. a (host, port) tuple. + If callback is specified, it will be called when the + connection is completed. + + Note that it is safe to call IOStream.write while the + connection is pending, in which case the data will be written + as soon as the connection is ready (see the example in the + docstring for this class). + """ + self._connecting = True + try: + self.socket.connect(address) + except socket.error, e: + # In non-blocking mode connect() always raises EINPROGRESS + if e.errno != errno.EINPROGRESS: + raise + self._connect_callback = stack_context.wrap(callback) + self._add_io_state(self.io_loop.WRITE) + def read_until(self, delimiter, callback): """Call callback when we read the given delimiter.""" assert not self._read_callback, "Already reading" @@ -160,6 +191,8 @@ class IOStream(object): if not self.socket: return if events & self.io_loop.WRITE: + if self._connecting: + self._handle_connect() self._handle_write() if not self.socket: return @@ -278,6 +311,13 @@ class IOStream(object): return True return False + def _handle_connect(self): + if self._connect_callback is not None: + callback = self._connect_callback + self._connect_callback = None + self._run_callback(callback) + self._connecting = False + def _handle_write(self): while self._write_buffer: try: @@ -316,11 +356,17 @@ class IOStream(object): class SSLIOStream(IOStream): - """Sets up an SSL connection in a non-blocking manner""" + """A utility class to write to and read from a non-blocking socket. + + If the socket passed to the constructor is already connected, + it should be wrapped with + ssl.wrap_socket(sock, do_handshake_on_connect=False, **kwargs) + before constructing the SSLIOStream. Unconnected sockets will be + wrapped when IOStream.connect is finished. + """ def __init__(self, *args, **kwargs): super(SSLIOStream, self).__init__(*args, **kwargs) self._ssl_accepting = True - self._do_ssl_handshake() def _do_ssl_handshake(self): # Based on code from test_ssl.py in the python stdlib @@ -355,6 +401,13 @@ class SSLIOStream(IOStream): return super(SSLIOStream, self)._handle_write() + def _handle_connect(self): + # TODO(bdarnell): cert verification, etc + self.socket = ssl.wrap_socket(self.socket, + do_handshake_on_connect=False) + super(SSLIOStream, self)._handle_connect() + + def _read_from_socket(self): try: # SSLSocket objects have both a read() and recv() method, diff --git a/tornado/simple_httpclient.py b/tornado/simple_httpclient.py index b155a32a7..34a06e5de 100644 --- a/tornado/simple_httpclient.py +++ b/tornado/simple_httpclient.py @@ -68,34 +68,24 @@ class _HTTPConnection(object): self.chunks = None with stack_context.StackContext(self.cleanup): parsed = urlparse.urlsplit(self.request.url) - sock = socket.socket() - sock.setblocking(False) if ":" in parsed.netloc: host, _, port = parsed.netloc.partition(":") port = int(port) else: host = parsed.netloc port = 443 if parsed.scheme == "https" else 80 - try: - sock.connect((host, port)) - except socket.error, e: - # In non-blocking mode connect() always raises EINPROGRESS - if e.errno != errno.EINPROGRESS: - raise - # Wait for the non-blocking connect to complete - self.io_loop.add_handler(sock.fileno(), - functools.partial(self._on_connect, - sock, parsed), - IOLoop.WRITE) - - def _on_connect(self, sock, parsed, fd, events): - self.io_loop.remove_handler(fd) - if parsed.scheme == "https": - # TODO: cert verification, etc - sock = ssl.wrap_socket(sock, do_handshake_on_connect=False) - self.stream = SSLIOStream(sock, io_loop=self.io_loop) - else: - self.stream = IOStream(sock, io_loop=self.io_loop) + + if parsed.scheme == "https": + # TODO: cert verification, etc + self.stream = SSLIOStream(socket.socket(), + io_loop=self.io_loop) + else: + self.stream = IOStream(socket.socket(), + io_loop=self.io_loop) + self.stream.connect((host, port), + functools.partial(self._on_connect, parsed)) + + def _on_connect(self, parsed): if "Host" not in self.request.headers: self.request.headers["Host"] = parsed.netloc has_body = self.request.method in ("POST", "PUT")