]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
Add initial version of non-curl-based AsyncHTTPClient.
author: Ben Darnell <ben@bendarnell.com>
Mon, 11 Oct 2010 21:16:52 +0000 (14:16 -0700)
committer: Ben Darnell <ben@bendarnell.com>
Mon, 11 Oct 2010 21:16:52 +0000 (14:16 -0700)
tornado/httpclient.py
tornado/simple_httpclient.py [new file with mode: 0644]
tornado/test/runtests.py
tornado/test/simple_httpclient_test.py [new file with mode: 0644]

index 26e447cc90348a959425bf81ca3aa98e4e0eb0f7..0c5d8e68b9a1eda2646711323eecc2ff97480aa0 100644 (file)
@@ -25,6 +25,7 @@ import email.utils
 import errno
 import httplib
 import logging
+import os
 import pycurl
 import sys
 import threading
@@ -649,5 +650,13 @@ def main():
         if options.print_body:
             print response.body
 
+# If the environment variable USE_SIMPLE_HTTPCLIENT is set to a non-empty
+# string, use SimpleAsyncHTTPClient instead of AsyncHTTPClient.
+# This is provided as a convenience for testing SimpleAsyncHTTPClient,
+# and may be removed or replaced with a better way of specifying the preferred
+# HTTPClient implementation before the next release.
+#
+# NOTE: the import below rebinds the module-level name AsyncHTTPClient.
+# It has to run after HTTPRequest, HTTPResponse and HTTPError are defined,
+# because tornado.simple_httpclient imports those names from this module.
+if os.environ.get('USE_SIMPLE_HTTPCLIENT'):
+    from tornado.simple_httpclient import SimpleAsyncHTTPClient as AsyncHTTPClient
+
 if __name__ == "__main__":
     main()
diff --git a/tornado/simple_httpclient.py b/tornado/simple_httpclient.py
new file mode 100644 (file)
index 0000000..b155a32
--- /dev/null
@@ -0,0 +1,202 @@
+#!/usr/bin/env python
+from __future__ import with_statement
+
+from cStringIO import StringIO
+from tornado.httpclient import HTTPRequest, HTTPResponse, HTTPError
+from tornado.httputil import HTTPHeaders
+from tornado.ioloop import IOLoop
+from tornado.iostream import IOStream, SSLIOStream
+from tornado import stack_context
+
+import contextlib
+import errno
+import functools
+import logging
+import re
+import socket
+import urlparse
+
+try:
+    import ssl # python 2.6+
+except ImportError:
+    ssl = None
+
+class SimpleAsyncHTTPClient(object):
+    """Non-blocking HTTP client built only on Tornado's own IOStreams.
+
+    WARNING:  This class is still in development and not yet recommended
+    for production use.
+
+    It speaks HTTP 1.1 but does not implement every applicable part of the
+    specification; it does enough to work with major web service APIs
+    (mostly tested against the Twitter API so far).
+
+    Compared to the curl-based AsyncHTTPClient many features are missing.
+    The HTTPRequest parameters currently honored are url, method, headers,
+    body, streaming_callback, and header_callback.  Each request opens a
+    new connection (no reuse), and the number of outstanding requests is
+    not limited.
+
+    HTTPS support needs Python 2.6 or higher.  On Python 2.5, use the
+    curl-based AsyncHTTPClient if HTTPS is required.
+    """
+    # TODO: singleton magic?
+    def __init__(self, io_loop=None):
+        """Bind the client to ``io_loop``, or to the global IOLoop instance
+        when none (or a falsy value) is given."""
+        self.io_loop = io_loop if io_loop else IOLoop.instance()
+
+    def close(self):
+        """No-op; this client holds no resources of its own."""
+        pass
+
+    def fetch(self, request, callback, **kwargs):
+        """Start an asynchronous request.
+
+        ``request`` is either an HTTPRequest or a URL; in the latter case an
+        HTTPRequest is built from the URL and ``kwargs``.  When an
+        HTTPRequest is passed, ``kwargs`` are not used.  ``callback`` is
+        wrapped with the current stack context and handed, together with
+        the request, to a new _HTTPConnection.
+        """
+        if isinstance(request, HTTPRequest):
+            req = request
+        else:
+            req = HTTPRequest(url=request, **kwargs)
+        # Normalize plain header mappings so later code can rely on the
+        # HTTPHeaders interface.
+        if not isinstance(req.headers, HTTPHeaders):
+            req.headers = HTTPHeaders(req.headers)
+        wrapped = stack_context.wrap(callback)
+        _HTTPConnection(self.io_loop, req, wrapped)
+
+
+
+class _HTTPConnection(object):
+    """Drives one HTTP request/response exchange over its own connection.
+
+    The constructor starts a non-blocking connect.  The rest of the work
+    happens in callbacks, in this order: ``_on_connect`` (send the request),
+    ``_on_headers`` (parse status line and headers), then either a single
+    ``_on_body`` for Content-Length responses or alternating
+    ``_on_chunk_length`` / ``_on_chunk_data`` calls followed by ``_on_body``
+    for chunked responses.
+
+    ``callback`` is invoked at most once with an HTTPResponse.  Exceptions
+    caught by ``cleanup`` are reported as a response with code 599.
+    """
+    def __init__(self, io_loop, request, callback):
+        """Begin connecting to the host named in ``request.url``.
+
+        io_loop  -- IOLoop used for the connect handler and the stream.
+        request  -- HTTPRequest whose headers are an HTTPHeaders instance.
+        callback -- callable taking the final HTTPResponse.
+        """
+        self.io_loop = io_loop
+        self.request = request
+        self.callback = callback
+        self.code = None
+        self.headers = None
+        # None for Content-Length responses; a list once chunked decoding
+        # has started (see _on_headers).
+        self.chunks = None
+        # Created in _on_connect; tracked here so it can always be closed.
+        self.stream = None
+        with stack_context.StackContext(self.cleanup):
+            parsed = urlparse.urlsplit(self.request.url)
+            sock = socket.socket()
+            sock.setblocking(False)
+            if ":" in parsed.netloc:
+                host, _, port = parsed.netloc.partition(":")
+                port = int(port)
+            else:
+                host = parsed.netloc
+                port = 443 if parsed.scheme == "https" else 80
+            try:
+                sock.connect((host, port))
+            except socket.error, e:
+                # In non-blocking mode connect() always raises EINPROGRESS.
+                # Use args[0] rather than e.errno: on Python 2.5
+                # socket.error has no errno attribute.
+                if e.args[0] != errno.EINPROGRESS:
+                    # Don't leak the descriptor on a real connect failure.
+                    sock.close()
+                    raise
+            # Wait for the non-blocking connect to complete
+            self.io_loop.add_handler(sock.fileno(),
+                                     functools.partial(self._on_connect,
+                                                       sock, parsed),
+                                     IOLoop.WRITE)
+
+    def _on_connect(self, sock, parsed, fd, events):
+        """Wrap the connected socket in a stream and send the request.
+
+        Called by the IOLoop once ``sock`` becomes writable.  Fills in the
+        Host, Content-Length and (for POST) Content-Type headers when
+        needed, writes the request line, headers and body, then starts
+        reading the response headers.
+        """
+        self.io_loop.remove_handler(fd)
+        if parsed.scheme == "https":
+            if ssl is None:
+                # The ssl module is unavailable (python < 2.6); fail with a
+                # clear message instead of an AttributeError on None.
+                sock.close()
+                raise Exception("HTTPS requires the ssl module "
+                                "(python 2.6+): %s" % self.request.url)
+            # TODO: cert verification, etc
+            sock = ssl.wrap_socket(sock, do_handshake_on_connect=False)
+            self.stream = SSLIOStream(sock, io_loop=self.io_loop)
+        else:
+            self.stream = IOStream(sock, io_loop=self.io_loop)
+        if "Host" not in self.request.headers:
+            self.request.headers["Host"] = parsed.netloc
+        has_body = self.request.method in ("POST", "PUT")
+        if has_body:
+            assert self.request.body is not None
+            self.request.headers["Content-Length"] = len(
+                self.request.body)
+        else:
+            assert self.request.body is None
+        if (self.request.method == "POST" and
+            "Content-Type" not in self.request.headers):
+            self.request.headers["Content-Type"] = "application/x-www-form-urlencoded"
+        req_path = ((parsed.path or '/') +
+                (('?' + parsed.query) if parsed.query else ''))
+        request_lines = ["%s %s HTTP/1.1" % (self.request.method,
+                                             req_path)]
+        for k, v in self.request.headers.get_all():
+            request_lines.append("%s: %s" % (k, v))
+        if logging.getLogger().isEnabledFor(logging.DEBUG):
+            for line in request_lines:
+                logging.debug(line)
+        self.stream.write("\r\n".join(request_lines) + "\r\n\r\n")
+        if has_body:
+            self.stream.write(self.request.body)
+        self.stream.read_until("\r\n\r\n", self._on_headers)
+
+    def _run_callback(self, response):
+        """Deliver ``response`` to the user callback at most once.
+
+        The callback is cleared *before* it is invoked, so that an
+        exception raised by the callback itself cannot make ``cleanup``
+        call it a second time.
+        """
+        if self.callback is not None:
+            callback = self.callback
+            self.callback = None
+            callback(response)
+
+    @contextlib.contextmanager
+    def cleanup(self):
+        """Context manager installed as this connection's StackContext.
+
+        Any exception raised inside it is logged, the stream (if one was
+        created) is closed, and the pending callback, if any, receives a
+        599 response carrying the exception.  The exception is not
+        re-raised.
+        """
+        try:
+            yield
+        except Exception, e:
+            logging.warning("uncaught exception", exc_info=True)
+            if self.stream is not None:
+                self.stream.close()
+            self._run_callback(HTTPResponse(self.request, 599, error=e))
+
+    def _on_headers(self, data):
+        """Parse the status line and headers, then start reading the body.
+
+        Raises Exception if the status line is malformed or if the
+        response has neither chunked encoding nor a Content-Length.
+        """
+        logging.debug(data)
+        first_line, _, header_data = data.partition("\r\n")
+        match = re.match("HTTP/1.[01] ([0-9]+) .*", first_line)
+        if not match:
+            # This is server-supplied data, so validate it explicitly
+            # rather than with an assert (asserts are stripped under -O).
+            raise Exception("Malformed HTTP status line: %r" % first_line)
+        self.code = int(match.group(1))
+        self.headers = HTTPHeaders.parse(header_data)
+        if self.request.header_callback is not None:
+            for k, v in self.headers.get_all():
+                self.request.header_callback("%s: %s\r\n" % (k, v))
+        if self.headers.get("Transfer-Encoding") == "chunked":
+            self.chunks = []
+            self.stream.read_until("\r\n", self._on_chunk_length)
+        elif "Content-Length" in self.headers:
+            self.stream.read_bytes(int(self.headers["Content-Length"]),
+                                   self._on_body)
+        else:
+            raise Exception("No Content-length or chunked encoding, "
+                            "don't know how to read %s" % self.request.url)
+
+    def _on_body(self, data):
+        """Finish the request: build the HTTPResponse and run the callback.
+
+        With a streaming_callback the response buffer is left empty, since
+        the data has been (or is here) passed to that callback instead.
+        """
+        if self.request.streaming_callback:
+            if self.chunks is None:
+                # if chunks is not None, we already called streaming_callback
+                # in _on_chunk_data
+                self.request.streaming_callback(data)
+            buffer = StringIO()
+        else:
+            buffer = StringIO(data) # TODO: don't require one big string?
+        response = HTTPResponse(self.request, self.code, headers=self.headers,
+                                buffer=buffer)
+        # Connections are not reused, so release the socket now.
+        self.stream.close()
+        self._run_callback(response)
+
+    def _on_chunk_length(self, data):
+        """Handle a chunk-size line; a size of zero ends the body."""
+        # TODO: "chunk extensions" http://tools.ietf.org/html/rfc2616#section-3.6.1
+        length = int(data.strip(), 16)
+        if length == 0:
+            self._on_body(''.join(self.chunks))
+        else:
+            self.stream.read_bytes(length + 2,  # chunk ends with \r\n
+                              self._on_chunk_data)
+
+    def _on_chunk_data(self, data):
+        """Consume one chunk (with its trailing CRLF) and read the next size.
+
+        The chunk goes to streaming_callback when one is set, otherwise it
+        is accumulated in ``self.chunks``.
+        """
+        assert data[-2:] == "\r\n"
+        chunk = data[:-2]
+        if self.request.streaming_callback is not None:
+            self.request.streaming_callback(chunk)
+        else:
+            self.chunks.append(chunk)
+        self.stream.read_until("\r\n", self._on_chunk_length)
+
+
+def main():
+    """Fetch each URL given on the command line, one at a time, and print
+    its body.
+
+    The IOLoop is started for every URL and stopped again from the
+    response callback.
+    """
+    from tornado.options import define, options, parse_command_line
+    urls = parse_command_line()
+    client = SimpleAsyncHTTPClient()
+    io_loop = IOLoop.instance()
+
+    def on_response(response):
+        # Stop the loop first so the loop below can move on to the next URL.
+        io_loop.stop()
+        response.rethrow()
+        print response.body
+
+    for url in urls:
+        client.fetch(url, on_response)
+        io_loop.start()
+
+if __name__ == "__main__":
+    main()
index f66199318ee75afdaa298efafeff1ea6af434f6c..1dedb6d1525128e4c1a05be70bbba3169ab08907 100755 (executable)
@@ -6,6 +6,7 @@ TEST_MODULES = [
     'tornado.test.httpserver_test',
     'tornado.test.ioloop_test',
     'tornado.test.iostream_test',
+    'tornado.test.simple_httpclient_test',
     'tornado.test.stack_context_test',
     'tornado.test.testing_test',
     'tornado.test.web_test',
diff --git a/tornado/test/simple_httpclient_test.py b/tornado/test/simple_httpclient_test.py
new file mode 100644 (file)
index 0000000..df40ec1
--- /dev/null
@@ -0,0 +1,74 @@
+#!/usr/bin/env python
+
+from tornado.simple_httpclient import SimpleAsyncHTTPClient
+from tornado.testing import AsyncHTTPTestCase, LogTrapTestCase
+from tornado.web import Application, RequestHandler
+
+class HelloWorldHandler(RequestHandler):
+  """Replies with a plain-text greeting for the optional ``name`` argument
+  (default "world")."""
+  def get(self):
+    self.set_header("Content-Type", "text/plain")
+    who = self.get_argument("name", "world")
+    greeting = "Hello %s!" % who
+    self.finish(greeting)
+
+class PostHandler(RequestHandler):
+  """Echoes the ``arg1`` and ``arg2`` arguments of a POST back in the body."""
+  def post(self):
+    first = self.get_argument("arg1")
+    second = self.get_argument("arg2")
+    self.finish("Post arg1: %s, arg2: %s" % (first, second))
+
+class ChunkHandler(RequestHandler):
+  """Writes the body in two pieces with an explicit flush in between;
+  test_chunked expects each piece to arrive as a separate chunk."""
+  def get(self):
+    first_piece, second_piece = "asdf", "qwer"
+    self.write(first_piece)
+    self.flush()
+    self.write(second_piece)
+
+class SimpleHTTPClientTestCase(AsyncHTTPTestCase, LogTrapTestCase):
+  # Exercises SimpleAsyncHTTPClient against the handlers registered in
+  # get_app().
+
+  def setUp(self):
+    super(SimpleHTTPClientTestCase, self).setUp()
+    # replace the client defined in the parent class
+    self.http_client = SimpleAsyncHTTPClient(io_loop=self.io_loop)
+
+  def get_app(self):
+    routes = [
+        ("/hello", HelloWorldHandler),
+        ("/post", PostHandler),
+        ("/chunk", ChunkHandler),
+        ]
+    return Application(routes)
+
+  def fetch(self, url, **kwargs):
+    # Start the request with self.stop as its callback, then return
+    # whatever self.wait() yields.
+    self.http_client.fetch(url, self.stop, **kwargs)
+    return self.wait()
+
+  def test_hello_world(self):
+    default_reply = self.fetch(self.get_url("/hello"))
+    self.assertEqual(default_reply.code, 200)
+    self.assertEqual(default_reply.headers["Content-Type"], "text/plain")
+    self.assertEqual(default_reply.body, "Hello world!")
+
+    named_reply = self.fetch(self.get_url("/hello?name=Ben"))
+    self.assertEqual(named_reply.body, "Hello Ben!")
+
+  def test_streaming_callback(self):
+    # streaming_callback is also tested in test_chunked
+    received = []
+    reply = self.fetch(self.get_url("/hello"),
+                       streaming_callback=received.append)
+    # with streaming_callback, data goes to the callback and not response.body
+    self.assertEqual(received, ["Hello world!"])
+    self.assertFalse(reply.body)
+
+  def test_post(self):
+    reply = self.fetch(self.get_url("/post"), method="POST",
+                       body="arg1=foo&arg2=bar")
+    self.assertEqual(reply.code, 200)
+    self.assertEqual(reply.body, "Post arg1: foo, arg2: bar")
+
+  def test_chunked(self):
+    # Without a streaming callback the chunks are joined into the body.
+    buffered = self.fetch(self.get_url("/chunk"))
+    self.assertEqual(buffered.body, "asdfqwer")
+
+    # With one, each chunk is delivered separately and the body is empty.
+    received = []
+    streamed = self.fetch(self.get_url("/chunk"),
+                          streaming_callback=received.append)
+    self.assertEqual(received, ["asdf", "qwer"])
+    self.assertFalse(streamed.body)
+
+