git.ipfire.org Git - thirdparty/tornado.git/commitdiff
Refactor IOStream reading logic to fix problems with SSL sockets.
author: Ben Darnell <ben@bendarnell.com>
Thu, 16 Sep 2010 20:45:13 +0000 (13:45 -0700)
committer: Ben Darnell <ben@bendarnell.com>
Thu, 16 Sep 2010 21:29:23 +0000 (14:29 -0700)
Now reads from the socket until it would block.  This ensures that we
see data in SSLSocket's internal buffers, which are otherwise invisible
to functions like select().

Added a test that makes a large POST request to an SSL server, which was
the scenario that originally exposed this bug.

tornado/iostream.py
tornado/test/httpserver_test.py

index 91ac29d7ed4604e6b236f4fa96a688c62746d924..3e76af12f3da1ea9b504bb3b1e622c5dfa30975a 100644 (file)
@@ -85,24 +85,28 @@ class IOStream(object):
     def read_until(self, delimiter, callback):
         """Call callback when we read the given delimiter."""
         assert not self._read_callback, "Already reading"
-        loc = self._read_buffer.find(delimiter)
-        if loc != -1:
-            self._run_callback(callback, self._consume(loc + len(delimiter)))
-            return
-        self._check_closed()
         self._read_delimiter = delimiter
         self._read_callback = callback
+        while True:
+            # See if we've already got the data from a previous read
+            if self._read_from_buffer():
+                return
+            self._check_closed()
+            if self._read_to_buffer() == 0:
+                break
         self._add_io_state(self.io_loop.READ)
 
     def read_bytes(self, num_bytes, callback):
         """Call callback when we read the given number of bytes."""
         assert not self._read_callback, "Already reading"
-        if len(self._read_buffer) >= num_bytes:
-            callback(self._consume(num_bytes))
-            return
-        self._check_closed()
         self._read_bytes = num_bytes
         self._read_callback = callback
+        while True:
+            if self._read_from_buffer():
+                return
+            self._check_closed()
+            if self._read_to_buffer() == 0:
+                break
         self._add_io_state(self.io_loop.READ)
 
     def write(self, data, callback=None):
@@ -180,24 +184,67 @@ class IOStream(object):
             raise
 
     def _handle_read(self):
+        while True:
+            try:
+                # Read from the socket until we get EWOULDBLOCK or equivalent.
+                # SSL sockets do some internal buffering, and if the data is
+                # sitting in the SSL object's buffer select() and friends
+                # can't see it; the only way to find out if it's there is to
+                # try to read it.
+                result = self._read_to_buffer()
+            except Exception:
+                self.close()
+                return
+            if result == 0:
+                break
+            else:
+                if self._read_from_buffer():
+                    return
+
+    def _read_from_socket(self):
+        """Attempts to read from the socket.
+
+        Returns the data read or None if there is nothing to read.
+        May be overridden in subclasses.
+        """
         try:
             chunk = self.socket.recv(self.read_chunk_size)
         except socket.error, e:
             if e.args[0] in (errno.EWOULDBLOCK, errno.EAGAIN):
-                return
+                return None
             else:
-                logging.warning("Read error on %d: %s",
-                                self.socket.fileno(), e)
-                self.close()
-                return
-        if not chunk:
+                raise
+        return chunk
+
+    def _read_to_buffer(self):
+        """Reads from the socket and appends the result to the read buffer.
+
+        Returns the number of bytes read.  Returns 0 if there is nothing
+        to read (i.e. the read returns EWOULDBLOCK or equivalent).  On
+        error closes the socket and raises an exception.
+        """
+        try:
+            chunk = self._read_from_socket()
+        except socket.error, e:
+            # ssl.SSLError is a subclass of socket.error
+            logging.warning("Read error on %d: %s",
+                            self.socket.fileno(), e)
             self.close()
-            return
+            raise
+        if chunk is None:
+            return 0
         self._read_buffer += chunk
         if len(self._read_buffer) >= self.max_buffer_size:
             logging.error("Reached maximum read buffer size")
             self.close()
-            return
+            raise IOError("Reached maximum read buffer size")
+        return len(chunk)
+
+    def _read_from_buffer(self):
+        """Attempts to complete the currently-pending read from the buffer.
+
+        Returns True if the read was completed.
+        """
         if self._read_bytes:
             if len(self._read_buffer) >= self._read_bytes:
                 num_bytes = self._read_bytes
@@ -205,6 +252,7 @@ class IOStream(object):
                 self._read_callback = None
                 self._read_bytes = None
                 self._run_callback(callback, self._consume(num_bytes))
+                return True
         elif self._read_delimiter:
             loc = self._read_buffer.find(self._read_delimiter)
             if loc != -1:
@@ -214,6 +262,8 @@ class IOStream(object):
                 self._read_delimiter = None
                 self._run_callback(callback,
                                    self._consume(loc + delimiter_len))
+                return True
+        return False
 
     def _handle_write(self):
         while self._write_buffer:
@@ -287,3 +337,25 @@ class SSLIOStream(IOStream):
             self._do_ssl_handshake()
             return
         super(SSLIOStream, self)._handle_write()
+
+    def _read_from_socket(self):
+        try:
+            # SSLSocket objects have both a read() and recv() method,
+            # while regular sockets only have recv().
+            # The recv() method blocks (at least in python 2.6) if it is
+            # called when there is nothing to read, so we have to use
+            # read() instead.
+            chunk = self.socket.read(self.read_chunk_size)
+        except ssl.SSLError, e:
+            # SSLError is a subclass of socket.error, so this except
+            # block must come first.
+            if e.args[0] == ssl.SSL_ERROR_WANT_READ:
+                return None
+            else:
+                raise
+        except socket.error, e:
+            if e.args[0] in (errno.EWOULDBLOCK, errno.EAGAIN):
+                return None
+            else:
+                raise
+        return chunk
index 3d6de203f23984b767a29f5c915ae9b6745f7631..6174d0afabb8f8f2121192bc991c80c0c97f61b9 100644 (file)
@@ -17,6 +17,9 @@ class HelloWorldRequestHandler(RequestHandler):
     def get(self):
         self.finish("Hello world")
 
+    def post(self):
+        self.finish("Got %d bytes in POST" % len(self.request.body))
+
 class SSLTest(AsyncHTTPTestCase, LogTrapTestCase):
     def get_app(self):
         return Application([('/', HelloWorldRequestHandler)])
@@ -29,16 +32,26 @@ class SSLTest(AsyncHTTPTestCase, LogTrapTestCase):
                 certfile=os.path.join(test_dir, 'test.crt'),
                 keyfile=os.path.join(test_dir, 'test.key')))
 
-    def test_ssl(self):
+    def fetch(self, path, **kwargs):
         def disable_cert_check(curl):
             # Our certificate was not signed by a CA, so don't check it
             curl.setopt(pycurl.SSL_VERIFYPEER, 0)
-        self.http_client.fetch(self.get_url('/').replace('http', 'https'),
+        self.http_client.fetch(self.get_url(path).replace('http', 'https'),
                                self.stop,
-                               prepare_curl_callback=disable_cert_check)
-        response = self.wait()
+                               prepare_curl_callback=disable_cert_check,
+                               **kwargs)
+        return self.wait()
+
+    def test_ssl(self):
+        response = self.fetch('/')
         self.assertEqual(response.body, "Hello world")
 
+    def test_large_post(self):
+        response = self.fetch('/',
+                              method='POST',
+                              body='A'*5000)
+        self.assertEqual(response.body, "Got 5000 bytes in POST")
+
 if (ssl is None or
     (pycurl.version_info()[5].startswith('GnuTLS') and
      pycurl.version_info()[2] < 0x71400)):