self.io_loop.remove_handler(self.socket.fileno())
self.socket.close()
self.socket = None
- if self._close_callback:
- self._run_callback(self._close_callback)
+ if self._close_callback and self._pending_callbacks == 0:
+ # if there are pending callbacks, don't run the close callback
+ # until they're done (see _maybe_add_error_handler)
+ cb = self._close_callback
+ self._close_callback = None
+ self._run_callback(cb)
def reading(self):
"""Returns true if we are currently reading from the stream."""
def _maybe_add_error_listener(self):
if self._state is None and self._pending_callbacks == 0:
- self._add_io_state(0)
+ if self.socket is None:
+ cb = self._close_callback
+ if cb is not None:
+ self._close_callback = None
+ self._run_callback(cb)
+ else:
+ self._add_io_state(0)
def _add_io_state(self, state):
"""Adds `state` (IOLoop.{READ,WRITE} flags) to our event handler.
from tornado.escape import utf8
from tornado.httpclient import AsyncHTTPClient
-from tornado.testing import AsyncHTTPTestCase, LogTrapTestCase
+from tornado.iostream import IOStream
+from tornado import netutil
+from tornado.testing import AsyncHTTPTestCase, LogTrapTestCase, get_unused_port
from tornado.util import b, bytes_type
from tornado.web import Application, RequestHandler, url
self.assertEqual(chunks, [b("asdf"), b("qwer")])
self.assertFalse(response.body)
+ def test_chunked_close(self):
+ # test case in which chunks spread read-callback processing
+ # over several ioloop iterations, but the connection is already closed.
+ port = get_unused_port()
+ (sock,) = netutil.bind_sockets(port, address="127.0.0.1")
+ def accept_callback(conn, address):
+ # fake an HTTP server using chunked encoding where the final chunks
+ # and connection close all happen at once
+ stream = IOStream(conn, io_loop=self.io_loop)
+ stream.write(b("""\
+HTTP/1.1 200 OK
+Transfer-Encoding: chunked
+
+1
+1
+1
+2
+0
+
+""").replace(b("\n"), b("\r\n")), callback=stream.close)
+ netutil.add_accept_handler(sock, accept_callback, self.io_loop)
+ self.http_client.fetch("http://127.0.0.1:%d/" % port, self.stop)
+ resp = self.wait()
+ resp.rethrow()
+ self.assertEqual(resp.body, b("12"))
+
+
def test_basic_auth(self):
self.assertEqual(self.fetch("/auth", auth_username="Aladdin",
auth_password="open sesame").body,