From: Ben Darnell Date: Sat, 3 May 2014 15:40:01 +0000 (-0400) Subject: Refactor connection logic from simple_httpclient to a new tcpclient module. X-Git-Tag: v4.0.0b1~65 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=2278cef74f55c4f01bd7a4baebc73c2052cb6849;p=thirdparty%2Ftornado.git Refactor connection logic from simple_httpclient to a new tcpclient module. This is preparation for introducing connection pooling and better handling of ipv6. If an IOStream was closed due to an exception, its Futures will now raise that exception instead of StreamClosedError. --- diff --git a/tornado/http1connection.py b/tornado/http1connection.py index b8109eed4..d71a3244a 100644 --- a/tornado/http1connection.py +++ b/tornado/http1connection.py @@ -604,7 +604,8 @@ class HTTP1ServerConnection(object): request_delegate = delegate.start_request(self, conn) try: ret = yield conn.read_response(request_delegate) - except iostream.StreamClosedError: + except (iostream.StreamClosedError, + iostream.UnsatisfiableReadError): return except Exception: # TODO: this is probably too broad; it would be better to diff --git a/tornado/iostream.py b/tornado/iostream.py index c8c074943..b5440b2a8 100644 --- a/tornado/iostream.py +++ b/tornado/iostream.py @@ -72,6 +72,7 @@ class StreamClosedError(IOError): pass + class UnsatisfiableReadError(Exception): """Exception raised when a read cannot be satisfied. @@ -364,7 +365,7 @@ class BaseIOStream(object): futures.append(self._connect_future) self._connect_future = None for future in futures: - future.set_exception(StreamClosedError()) + future.set_exception(self.error or StreamClosedError()) if self._close_callback is not None: cb = self._close_callback self._close_callback = None @@ -408,13 +409,19 @@ class BaseIOStream(object): gen_log.warning("Got events for closed stream %s", fd) return try: + if self._connecting: + # Most IOLoops will report a write failed connect + # with the WRITE event, but SelectIOLoop reports a + # READ as well so we must check for connecting before + # either. + self._handle_connect() + if self.closed(): + return if events & self.io_loop.READ: self._handle_read() if self.closed(): return if events & self.io_loop.WRITE: - if self._connecting: - self._handle_connect() self._handle_write() if self.closed(): return diff --git a/tornado/simple_httpclient.py b/tornado/simple_httpclient.py index b7be7950f..06d2ca801 100644 --- a/tornado/simple_httpclient.py +++ b/tornado/simple_httpclient.py @@ -10,6 +10,7 @@ from tornado.iostream import IOStream, SSLIOStream, StreamClosedError from tornado.netutil import Resolver, OverrideResolver from tornado.log import gen_log from tornado import stack_context +from tornado.tcpclient import TCPClient import base64 import collections @@ -88,6 +89,8 @@ class SimpleAsyncHTTPClient(AsyncHTTPClient): self.waiting = {} self.max_buffer_size = max_buffer_size self.max_header_size = max_header_size + # TCPClient could create a Resolver for us, but we have to do it + # ourselves to support hostname_mapping. if resolver: self.resolver = resolver self.own_resolver = False @@ -97,11 +100,13 @@ class SimpleAsyncHTTPClient(AsyncHTTPClient): if hostname_mapping is not None: self.resolver = OverrideResolver(resolver=self.resolver, mapping=hostname_mapping) + self.tcp_client = TCPClient(resolver=self.resolver, io_loop=io_loop) def close(self): super(SimpleAsyncHTTPClient, self).close() if self.own_resolver: self.resolver.close() + self.tcp_client.close() def fetch_impl(self, request, callback): key = object() @@ -133,7 +138,7 @@ class SimpleAsyncHTTPClient(AsyncHTTPClient): def _handle_request(self, request, release_callback, final_callback): _HTTPConnection(self.io_loop, self, request, release_callback, - final_callback, self.max_buffer_size, self.resolver, + final_callback, self.max_buffer_size, self.tcp_client, self.max_header_size) def _release_fetch(self, key): @@ -161,7 +166,7 @@ class _HTTPConnection(httputil.HTTPMessageDelegate): _SUPPORTED_METHODS = set(["GET", "HEAD", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"]) def __init__(self, io_loop, client, request, release_callback, - final_callback, max_buffer_size, resolver, + final_callback, max_buffer_size, tcp_client, max_header_size): self.start_time = io_loop.time() self.io_loop = io_loop @@ -170,7 +175,7 @@ class _HTTPConnection(httputil.HTTPMessageDelegate): self.release_callback = release_callback self.final_callback = final_callback self.max_buffer_size = max_buffer_size - self.resolver = resolver + self.tcp_client = tcp_client self.max_header_size = max_header_size self.code = None self.headers = None @@ -208,28 +213,19 @@ class _HTTPConnection(httputil.HTTPMessageDelegate): # so restrict to ipv4 by default. af = socket.AF_INET + ssl_options = self._get_ssl_options(self.parsed.scheme) + timeout = min(self.request.connect_timeout, self.request.request_timeout) if timeout: self._timeout = self.io_loop.add_timeout( self.start_time + timeout, stack_context.wrap(self._on_timeout)) - self.resolver.resolve(host, port, af, callback=self._on_resolve) + self.tcp_client.connect(host, port, af=af, + ssl_options=ssl_options, + callback=self._on_connect) - def _on_resolve(self, addrinfo): - if self.final_callback is None: - # final_callback is cleared if we've hit our timeout - return - self.stream = self._create_stream(addrinfo) - self.stream.set_close_callback(self._on_close) - # ipv6 addresses are broken (in self.parsed.hostname) until - # 2.7, here is correctly parsed value calculated in __init__ - self._sockaddr = addrinfo[0][1] - self.stream.connect(self._sockaddr, self._on_connect, - server_hostname=self.parsed_hostname) - - def _create_stream(self, addrinfo): - af = addrinfo[0][0] - if self.parsed.scheme == "https": + def _get_ssl_options(self, scheme): + if scheme == "https": ssl_options = {} if self.request.validate_cert: ssl_options["cert_reqs"] = ssl.CERT_REQUIRED @@ -262,15 +258,8 @@ class _HTTPConnection(httputil.HTTPMessageDelegate): # of openssl, but python 2.6 doesn't expose version # information. ssl_options["ssl_version"] = ssl.PROTOCOL_TLSv1 - - return SSLIOStream(socket.socket(af), - io_loop=self.io_loop, - ssl_options=ssl_options, - max_buffer_size=self.max_buffer_size) - else: - return IOStream(socket.socket(af), - io_loop=self.io_loop, - max_buffer_size=self.max_buffer_size) + return ssl_options + return None def _on_timeout(self): self._timeout = None @@ -282,7 +271,13 @@ class _HTTPConnection(httputil.HTTPMessageDelegate): self.io_loop.remove_timeout(self._timeout) self._timeout = None - def _on_connect(self): + def _on_connect(self, stream): + if self.final_callback is None: + # final_callback is cleared if we've hit our timeout. + stream.close() + return + self.stream = stream + self.stream.set_close_callback(self._on_close) self._remove_timeout() if self.final_callback is None: return diff --git a/tornado/tcpclient.py b/tornado/tcpclient.py new file mode 100644 index 000000000..e41c39094 --- /dev/null +++ b/tornado/tcpclient.py @@ -0,0 +1,67 @@ +#!/usr/bin/env python +# +# Copyright 2014 Facebook +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +"""A non-blocking TCP connection factory. +""" +from __future__ import absolute_import, division, print_function, with_statement + +import socket + +from tornado.ioloop import IOLoop +from tornado.iostream import IOStream, SSLIOStream +from tornado import gen +from tornado.netutil import Resolver + +class TCPClient(object): + """A non-blocking TCP connection factory. + """ + def __init__(self, resolver=None, io_loop=None): + self.io_loop = io_loop or IOLoop.current() + if resolver is not None: + self.resolver = resolver + self._own_resolver = False + else: + self.resolver = Resolver(io_loop=io_loop) + self._own_resolver = True + + def close(self): + if self._own_resolver: + self.resolver.close() + + @gen.coroutine + def connect(self, host, port, af=socket.AF_UNSPEC, ssl_options=None, + max_buffer_size=None): + """Connect to the given host and port. + + Asynchronously returns an `.IOStream` (or `.SSLIOStream` if + ``ssl_options`` is not None). + """ + addrinfo = yield self.resolver.resolve(host, port, af) + af, addr = addrinfo[0] + stream = self._create_stream(af, ssl_options, max_buffer_size) + yield stream.connect(addr, server_hostname=host) + raise gen.Return(stream) + + def _create_stream(self, af, ssl_options, max_buffer_size): + if ssl_options is None: + return IOStream(socket.socket(af), + io_loop=self.io_loop, + max_buffer_size=max_buffer_size) + else: + return SSLIOStream(socket.socket(af), + io_loop=self.io_loop, + ssl_options=ssl_options, + max_buffer_size=max_buffer_size) diff --git a/tornado/test/httpserver_test.py b/tornado/test/httpserver_test.py index f9e920cd9..8387c23cf 100644 --- a/tornado/test/httpserver_test.py +++ b/tornado/test/httpserver_test.py @@ -9,7 +9,7 @@ from tornado.http1connection import HTTP1Connection from tornado.httpserver import HTTPServer from tornado.httputil import HTTPHeaders, HTTPMessageDelegate, HTTPServerConnectionDelegate, ResponseStartLine from tornado.iostream import IOStream -from tornado.log import gen_log +from tornado.log import gen_log, app_log from tornado.netutil import ssl_options_to_context from tornado.simple_httpclient import SimpleAsyncHTTPClient from tornado.testing import AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase, ExpectLog, gen_test @@ -111,11 +111,15 @@ class SSLTestMixin(object): # connection, rather than waiting for a timeout or otherwise # misbehaving. with ExpectLog(gen_log, '(SSL Error|uncaught exception)'): - self.http_client.fetch(self.get_url("/").replace('https:', 'http:'), - self.stop, - request_timeout=3600, - connect_timeout=3600) - response = self.wait() + # TODO: this should go to gen_log, not app_log. See TODO + # in http1connection.py (_server_request_loop) + with ExpectLog(app_log, 'Uncaught exception', required=False): + self.http_client.fetch( + self.get_url("/").replace('https:', 'http:'), + self.stop, + request_timeout=3600, + connect_timeout=3600) + response = self.wait() self.assertEqual(response.code, 599) # Python's SSL implementation differs significantly between versions. diff --git a/tornado/test/runtests.py b/tornado/test/runtests.py index a1fb32952..c1c5746b0 100644 --- a/tornado/test/runtests.py +++ b/tornado/test/runtests.py @@ -40,6 +40,7 @@ TEST_MODULES = [ 'tornado.test.process_test', 'tornado.test.simple_httpclient_test', 'tornado.test.stack_context_test', + 'tornado.test.tcpclient_test', 'tornado.test.template_test', 'tornado.test.testing_test', 'tornado.test.twisted_test', diff --git a/tornado/test/tcpclient_test.py b/tornado/test/tcpclient_test.py new file mode 100644 index 000000000..73bc3c103 --- /dev/null +++ b/tornado/test/tcpclient_test.py @@ -0,0 +1,67 @@ +#!/usr/bin/env python +# +# Copyright 2014 Facebook +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from __future__ import absolute_import, division, print_function, with_statement + +from contextlib import closing + +from tornado.log import gen_log +from tornado.tcpclient import TCPClient +from tornado.tcpserver import TCPServer +from tornado.testing import AsyncTestCase, bind_unused_port, gen_test, ExpectLog + +class TestTCPServer(TCPServer): + def __init__(self): + super(TestTCPServer, self).__init__() + self.streams = [] + socket, self.port = bind_unused_port() + self.add_socket(socket) + + def handle_stream(self, stream, address): + self.streams.append(stream) + + def stop(self): + super(TestTCPServer, self).stop() + for stream in self.streams: + stream.close() + +class TCPClientTest(AsyncTestCase): + def setUp(self): + super(TCPClientTest, self).setUp() + self.server = TestTCPServer() + self.port = self.server.port + self.client = TCPClient() + + def tearDown(self): + self.client.close() + self.server.stop() + super(TCPClientTest, self).tearDown() + + @gen_test + def test_connect_ipv4(self): + stream = yield self.client.connect('127.0.0.1', self.port) + with closing(stream): + stream.write(b"hello") + data = yield self.server.streams[0].read_bytes(5) + self.assertEqual(data, b"hello") + + @gen_test + def test_refused_ipv4(self): + sock, port = bind_unused_port() + sock.close() + with ExpectLog(gen_log, 'Connect error'): + with self.assertRaises(IOError): + yield self.client.connect('127.0.0.1', port) diff --git a/tornado/websocket.py b/tornado/websocket.py index 218413f89..3db215dba 100644 --- a/tornado/websocket.py +++ b/tornado/websocket.py @@ -37,8 +37,8 @@ from tornado import httpclient, httputil from tornado.ioloop import IOLoop from tornado.iostream import StreamClosedError from tornado.log import gen_log, app_log -from tornado.netutil import Resolver from tornado import simple_httpclient +from tornado.tcpclient import TCPClient from tornado.util import bytes_type, unicode_type try: @@ -821,10 +821,10 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection): 'Sec-WebSocket-Version': '13', }) - self.resolver = Resolver(io_loop=io_loop) + self.tcp_client = TCPClient(io_loop=io_loop) super(WebSocketClientConnection, self).__init__( io_loop, None, request, lambda: None, self._on_http_response, - 104857600, self.resolver, 65536) + 104857600, self.tcp_client, 65536) def close(self, code=None, reason=None): """Closes the websocket connection.