From: Ben Darnell Date: Thu, 16 Sep 2010 20:45:13 +0000 (-0700) Subject: Refactor IOStream reading logic to fix problems with SSL sockets. X-Git-Tag: v1.2.0~128 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=f6f7f83719a48083f3e5ed999e19c0ec67de9b9b;p=thirdparty%2Ftornado.git Refactor IOStream reading logic to fix problems with SSL sockets. Now reads from the socket until it would block. This ensures that we see data in SSLSocket's internal buffers, which are otherwise invisible to functions like select(). Added a test that makes a large POST request to an SSL server, which was the scenario that originally exposed this bug. --- diff --git a/tornado/iostream.py b/tornado/iostream.py index 91ac29d7e..3e76af12f 100644 --- a/tornado/iostream.py +++ b/tornado/iostream.py @@ -85,24 +85,28 @@ class IOStream(object): def read_until(self, delimiter, callback): """Call callback when we read the given delimiter.""" assert not self._read_callback, "Already reading" - loc = self._read_buffer.find(delimiter) - if loc != -1: - self._run_callback(callback, self._consume(loc + len(delimiter))) - return - self._check_closed() self._read_delimiter = delimiter self._read_callback = callback + while True: + # See if we've already got the data from a previous read + if self._read_from_buffer(): + return + self._check_closed() + if self._read_to_buffer() == 0: + break self._add_io_state(self.io_loop.READ) def read_bytes(self, num_bytes, callback): """Call callback when we read the given number of bytes.""" assert not self._read_callback, "Already reading" - if len(self._read_buffer) >= num_bytes: - callback(self._consume(num_bytes)) - return - self._check_closed() self._read_bytes = num_bytes self._read_callback = callback + while True: + if self._read_from_buffer(): + return + self._check_closed() + if self._read_to_buffer() == 0: + break self._add_io_state(self.io_loop.READ) def write(self, data, callback=None): @@ -180,24 +184,67 @@ class IOStream(object): raise def _handle_read(self): + while True: + try: + # Read from the socket until we get EWOULDBLOCK or equivalent. + # SSL sockets do some internal buffering, and if the data is + # sitting in the SSL object's buffer select() and friends + # can't see it; the only way to find out if it's there is to + # try to read it. + result = self._read_to_buffer() + except Exception: + self.close() + return + if result == 0: + break + else: + if self._read_from_buffer(): + return + + def _read_from_socket(self): + """Attempts to read from the socket. + + Returns the data read or None if there is nothing to read. + May be overridden in subclasses. + """ try: chunk = self.socket.recv(self.read_chunk_size) except socket.error, e: if e.args[0] in (errno.EWOULDBLOCK, errno.EAGAIN): - return + return None else: - logging.warning("Read error on %d: %s", - self.socket.fileno(), e) - self.close() - return - if not chunk: + raise + return chunk + + def _read_to_buffer(self): + """Reads from the socket and appends the result to the read buffer. + + Returns the number of bytes read. Returns 0 if there is nothing + to read (i.e. the read returns EWOULDBLOCK or equivalent). On + error closes the socket and raises an exception. + """ + try: + chunk = self._read_from_socket() + except socket.error, e: + # ssl.SSLError is a subclass of socket.error + logging.warning("Read error on %d: %s", + self.socket.fileno(), e) self.close() - return + raise + if chunk is None: + return 0 self._read_buffer += chunk if len(self._read_buffer) >= self.max_buffer_size: logging.error("Reached maximum read buffer size") self.close() - return + raise IOError("Reached maximum read buffer size") + return len(chunk) + + def _read_from_buffer(self): + """Attempts to complete the currently-pending read from the buffer. + + Returns True if the read was completed. + """ if self._read_bytes: if len(self._read_buffer) >= self._read_bytes: num_bytes = self._read_bytes @@ -205,6 +252,7 @@ class IOStream(object): self._read_callback = None self._read_bytes = None self._run_callback(callback, self._consume(num_bytes)) + return True elif self._read_delimiter: loc = self._read_buffer.find(self._read_delimiter) if loc != -1: @@ -214,6 +262,8 @@ class IOStream(object): self._read_delimiter = None self._run_callback(callback, self._consume(loc + delimiter_len)) + return True + return False def _handle_write(self): while self._write_buffer: @@ -287,3 +337,25 @@ class SSLIOStream(IOStream): self._do_ssl_handshake() return super(SSLIOStream, self)._handle_write() + + def _read_from_socket(self): + try: + # SSLSocket objects have both a read() and recv() method, + # while regular sockets only have recv(). + # The recv() method blocks (at least in python 2.6) if it is + # called when there is nothing to read, so we have to use + # read() instead. + chunk = self.socket.read(self.read_chunk_size) + except ssl.SSLError, e: + # SSLError is a subclass of socket.error, so this except + # block must come first. + if e.args[0] == ssl.SSL_ERROR_WANT_READ: + return None + else: + raise + except socket.error, e: + if e.args[0] in (errno.EWOULDBLOCK, errno.EAGAIN): + return None + else: + raise + return chunk diff --git a/tornado/test/httpserver_test.py b/tornado/test/httpserver_test.py index 3d6de203f..6174d0afa 100644 --- a/tornado/test/httpserver_test.py +++ b/tornado/test/httpserver_test.py @@ -17,6 +17,9 @@ class HelloWorldRequestHandler(RequestHandler): def get(self): self.finish("Hello world") + def post(self): + self.finish("Got %d bytes in POST" % len(self.request.body)) + class SSLTest(AsyncHTTPTestCase, LogTrapTestCase): def get_app(self): return Application([('/', HelloWorldRequestHandler)]) @@ -29,16 +32,26 @@ class SSLTest(AsyncHTTPTestCase, LogTrapTestCase): certfile=os.path.join(test_dir, 'test.crt'), keyfile=os.path.join(test_dir, 'test.key'))) - def test_ssl(self): + def fetch(self, path, **kwargs): def disable_cert_check(curl): # Our certificate was not signed by a CA, so don't check it curl.setopt(pycurl.SSL_VERIFYPEER, 0) - self.http_client.fetch(self.get_url('/').replace('http', 'https'), + self.http_client.fetch(self.get_url(path).replace('http', 'https'), self.stop, - prepare_curl_callback=disable_cert_check) - response = self.wait() + prepare_curl_callback=disable_cert_check, + **kwargs) + return self.wait() + + def test_ssl(self): + response = self.fetch('/') self.assertEqual(response.body, "Hello world") + def test_large_post(self): + response = self.fetch('/', + method='POST', + body='A'*5000) + self.assertEqual(response.body, "Got 5000 bytes in POST") + if (ssl is None or (pycurl.version_info()[5].startswith('GnuTLS') and pycurl.version_info()[2] < 0x71400)):