request_delegate = delegate.start_request(self, conn)
try:
ret = yield conn.read_response(request_delegate)
- except iostream.StreamClosedError:
+ except (iostream.StreamClosedError,
+ iostream.UnsatisfiableReadError):
return
except Exception:
# TODO: this is probably too broad; it would be better to
pass
+
class UnsatisfiableReadError(Exception):
"""Exception raised when a read cannot be satisfied.
futures.append(self._connect_future)
self._connect_future = None
for future in futures:
- future.set_exception(StreamClosedError())
+ future.set_exception(self.error or StreamClosedError())
if self._close_callback is not None:
cb = self._close_callback
self._close_callback = None
gen_log.warning("Got events for closed stream %s", fd)
return
try:
+ if self._connecting:
+ # Most IOLoops will report a write failed connect
+ # with the WRITE event, but SelectIOLoop reports a
+ # READ as well so we must check for connecting before
+ # either.
+ self._handle_connect()
+ if self.closed():
+ return
if events & self.io_loop.READ:
self._handle_read()
if self.closed():
return
if events & self.io_loop.WRITE:
- if self._connecting:
- self._handle_connect()
self._handle_write()
if self.closed():
return
from tornado.netutil import Resolver, OverrideResolver
from tornado.log import gen_log
from tornado import stack_context
+from tornado.tcpclient import TCPClient
import base64
import collections
self.waiting = {}
self.max_buffer_size = max_buffer_size
self.max_header_size = max_header_size
+ # TCPClient could create a Resolver for us, but we have to do it
+ # ourselves to support hostname_mapping.
if resolver:
self.resolver = resolver
self.own_resolver = False
if hostname_mapping is not None:
self.resolver = OverrideResolver(resolver=self.resolver,
mapping=hostname_mapping)
+ self.tcp_client = TCPClient(resolver=self.resolver, io_loop=io_loop)
def close(self):
super(SimpleAsyncHTTPClient, self).close()
if self.own_resolver:
self.resolver.close()
+ self.tcp_client.close()
def fetch_impl(self, request, callback):
key = object()
def _handle_request(self, request, release_callback, final_callback):
_HTTPConnection(self.io_loop, self, request, release_callback,
- final_callback, self.max_buffer_size, self.resolver,
+ final_callback, self.max_buffer_size, self.tcp_client,
self.max_header_size)
def _release_fetch(self, key):
_SUPPORTED_METHODS = set(["GET", "HEAD", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"])
def __init__(self, io_loop, client, request, release_callback,
- final_callback, max_buffer_size, resolver,
+ final_callback, max_buffer_size, tcp_client,
max_header_size):
self.start_time = io_loop.time()
self.io_loop = io_loop
self.release_callback = release_callback
self.final_callback = final_callback
self.max_buffer_size = max_buffer_size
- self.resolver = resolver
+ self.tcp_client = tcp_client
self.max_header_size = max_header_size
self.code = None
self.headers = None
# so restrict to ipv4 by default.
af = socket.AF_INET
+ ssl_options = self._get_ssl_options(self.parsed.scheme)
+
timeout = min(self.request.connect_timeout, self.request.request_timeout)
if timeout:
self._timeout = self.io_loop.add_timeout(
self.start_time + timeout,
stack_context.wrap(self._on_timeout))
- self.resolver.resolve(host, port, af, callback=self._on_resolve)
+ self.tcp_client.connect(host, port, af=af,
+ ssl_options=ssl_options,
+ callback=self._on_connect)
- def _on_resolve(self, addrinfo):
- if self.final_callback is None:
- # final_callback is cleared if we've hit our timeout
- return
- self.stream = self._create_stream(addrinfo)
- self.stream.set_close_callback(self._on_close)
- # ipv6 addresses are broken (in self.parsed.hostname) until
- # 2.7, here is correctly parsed value calculated in __init__
- self._sockaddr = addrinfo[0][1]
- self.stream.connect(self._sockaddr, self._on_connect,
- server_hostname=self.parsed_hostname)
-
- def _create_stream(self, addrinfo):
- af = addrinfo[0][0]
- if self.parsed.scheme == "https":
+ def _get_ssl_options(self, scheme):
+ if scheme == "https":
ssl_options = {}
if self.request.validate_cert:
ssl_options["cert_reqs"] = ssl.CERT_REQUIRED
# of openssl, but python 2.6 doesn't expose version
# information.
ssl_options["ssl_version"] = ssl.PROTOCOL_TLSv1
-
- return SSLIOStream(socket.socket(af),
- io_loop=self.io_loop,
- ssl_options=ssl_options,
- max_buffer_size=self.max_buffer_size)
- else:
- return IOStream(socket.socket(af),
- io_loop=self.io_loop,
- max_buffer_size=self.max_buffer_size)
+ return ssl_options
+ return None
def _on_timeout(self):
self._timeout = None
self.io_loop.remove_timeout(self._timeout)
self._timeout = None
- def _on_connect(self):
+ def _on_connect(self, stream):
+ if self.final_callback is None:
+ # final_callback is cleared if we've hit our timeout.
+ stream.close()
+ return
+ self.stream = stream
+ self.stream.set_close_callback(self._on_close)
self._remove_timeout()
if self.final_callback is None:
return
--- /dev/null
+#!/usr/bin/env python
+#
+# Copyright 2014 Facebook
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+"""A non-blocking TCP connection factory.
+"""
+from __future__ import absolute_import, division, print_function, with_statement
+
+import socket
+
+from tornado.ioloop import IOLoop
+from tornado.iostream import IOStream, SSLIOStream
+from tornado import gen
+from tornado.netutil import Resolver
+
+class TCPClient(object):
+ """A non-blocking TCP connection factory.
+ """
+ def __init__(self, resolver=None, io_loop=None):
+ self.io_loop = io_loop or IOLoop.current()
+ if resolver is not None:
+ self.resolver = resolver
+ self._own_resolver = False
+ else:
+ self.resolver = Resolver(io_loop=io_loop)
+ self._own_resolver = True
+
+ def close(self):
+ if self._own_resolver:
+ self.resolver.close()
+
+ @gen.coroutine
+ def connect(self, host, port, af=socket.AF_UNSPEC, ssl_options=None,
+ max_buffer_size=None):
+ """Connect to the given host and port.
+
+ Asynchronously returns an `.IOStream` (or `.SSLIOStream` if
+ ``ssl_options`` is not None).
+ """
+ addrinfo = yield self.resolver.resolve(host, port, af)
+ af, addr = addrinfo[0]
+ stream = self._create_stream(af, ssl_options, max_buffer_size)
+ yield stream.connect(addr, server_hostname=host)
+ raise gen.Return(stream)
+
+ def _create_stream(self, af, ssl_options, max_buffer_size):
+ if ssl_options is None:
+ return IOStream(socket.socket(af),
+ io_loop=self.io_loop,
+ max_buffer_size=max_buffer_size)
+ else:
+ return SSLIOStream(socket.socket(af),
+ io_loop=self.io_loop,
+ ssl_options=ssl_options,
+ max_buffer_size=max_buffer_size)
from tornado.httpserver import HTTPServer
from tornado.httputil import HTTPHeaders, HTTPMessageDelegate, HTTPServerConnectionDelegate, ResponseStartLine
from tornado.iostream import IOStream
-from tornado.log import gen_log
+from tornado.log import gen_log, app_log
from tornado.netutil import ssl_options_to_context
from tornado.simple_httpclient import SimpleAsyncHTTPClient
from tornado.testing import AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase, ExpectLog, gen_test
# connection, rather than waiting for a timeout or otherwise
# misbehaving.
with ExpectLog(gen_log, '(SSL Error|uncaught exception)'):
- self.http_client.fetch(self.get_url("/").replace('https:', 'http:'),
- self.stop,
- request_timeout=3600,
- connect_timeout=3600)
- response = self.wait()
+ # TODO: this should go to gen_log, not app_log. See TODO
+ # in http1connection.py (_server_request_loop)
+ with ExpectLog(app_log, 'Uncaught exception', required=False):
+ self.http_client.fetch(
+ self.get_url("/").replace('https:', 'http:'),
+ self.stop,
+ request_timeout=3600,
+ connect_timeout=3600)
+ response = self.wait()
self.assertEqual(response.code, 599)
# Python's SSL implementation differs significantly between versions.
'tornado.test.process_test',
'tornado.test.simple_httpclient_test',
'tornado.test.stack_context_test',
+ 'tornado.test.tcpclient_test',
'tornado.test.template_test',
'tornado.test.testing_test',
'tornado.test.twisted_test',
--- /dev/null
+#!/usr/bin/env python
+#
+# Copyright 2014 Facebook
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+from __future__ import absolute_import, division, print_function, with_statement
+
+from contextlib import closing
+
+from tornado.log import gen_log
+from tornado.tcpclient import TCPClient
+from tornado.tcpserver import TCPServer
+from tornado.testing import AsyncTestCase, bind_unused_port, gen_test, ExpectLog
+
+class TestTCPServer(TCPServer):
+ def __init__(self):
+ super(TestTCPServer, self).__init__()
+ self.streams = []
+ socket, self.port = bind_unused_port()
+ self.add_socket(socket)
+
+ def handle_stream(self, stream, address):
+ self.streams.append(stream)
+
+ def stop(self):
+ super(TestTCPServer, self).stop()
+ for stream in self.streams:
+ stream.close()
+
+class TCPClientTest(AsyncTestCase):
+ def setUp(self):
+ super(TCPClientTest, self).setUp()
+ self.server = TestTCPServer()
+ self.port = self.server.port
+ self.client = TCPClient()
+
+ def tearDown(self):
+ self.client.close()
+ self.server.stop()
+ super(TCPClientTest, self).tearDown()
+
+ @gen_test
+ def test_connect_ipv4(self):
+ stream = yield self.client.connect('127.0.0.1', self.port)
+ with closing(stream):
+ stream.write(b"hello")
+ data = yield self.server.streams[0].read_bytes(5)
+ self.assertEqual(data, b"hello")
+
+ @gen_test
+ def test_refused_ipv4(self):
+ sock, port = bind_unused_port()
+ sock.close()
+ with ExpectLog(gen_log, 'Connect error'):
+ with self.assertRaises(IOError):
+ yield self.client.connect('127.0.0.1', port)
from tornado.ioloop import IOLoop
from tornado.iostream import StreamClosedError
from tornado.log import gen_log, app_log
-from tornado.netutil import Resolver
from tornado import simple_httpclient
+from tornado.tcpclient import TCPClient
from tornado.util import bytes_type, unicode_type
try:
'Sec-WebSocket-Version': '13',
})
- self.resolver = Resolver(io_loop=io_loop)
+ self.tcp_client = TCPClient(io_loop=io_loop)
super(WebSocketClientConnection, self).__init__(
io_loop, None, request, lambda: None, self._on_http_response,
- 104857600, self.resolver, 65536)
+ 104857600, self.tcp_client, 65536)
def close(self, code=None, reason=None):
"""Closes the websocket connection.