]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
iostream: Add type annotations
authorBen Darnell <ben@bendarnell.com>
Sun, 12 Aug 2018 16:26:00 +0000 (12:26 -0400)
committerBen Darnell <ben@bendarnell.com>
Mon, 10 Sep 2018 04:20:09 +0000 (00:20 -0400)
setup.cfg
tornado/iostream.py
tornado/test/iostream_test.py

index cde5d24f27b7b9f0b701a08bc815c2ca033cb37e..c843ce7b3f16b06749b6b60fb0537d292b1b93f2 100644 (file)
--- a/setup.cfg
+++ b/setup.cfg
@@ -22,6 +22,9 @@ disallow_untyped_defs = True
 [mypy-tornado.ioloop]
 disallow_untyped_defs = True
 
+[mypy-tornado.iostream]
+disallow_untyped_defs = True
+
 [mypy-tornado.locale]
 disallow_untyped_defs = True
 
@@ -63,6 +66,9 @@ check_untyped_defs = True
 [mypy-tornado.test.ioloop_test]
 check_untyped_defs = True
 
+[mypy-tornado.test.iostream_test]
+check_untyped_defs = True
+
 [mypy-tornado.test.locale_test]
 check_untyped_defs = True
 
index 514c36f503cedaa37d6c7b9040a6ce1f5eaa015a..fe41cebafa7b7d05b89b7c87dbb60cdd329f2782 100644 (file)
@@ -39,6 +39,14 @@ from tornado.log import gen_log
 from tornado.netutil import ssl_wrap_socket, _client_ssl_defaults, _server_ssl_defaults
 from tornado.util import errno_from_exception
 
+import typing
+from typing import Union, Optional, Awaitable, Callable, Type, Pattern, Any, Dict, TypeVar, Tuple
+from types import TracebackType
+if typing.TYPE_CHECKING:
+    from typing import Deque, List  # noqa: F401
+
+_IOStreamType = TypeVar('_IOStreamType', bound='IOStream')
+
 try:
     from tornado.platform.posix import _set_nonblocking
 except ImportError:
@@ -90,7 +98,7 @@ class StreamClosedError(IOError):
     .. versionchanged:: 4.3
        Added the ``real_error`` attribute.
     """
-    def __init__(self, real_error=None):
+    def __init__(self, real_error: BaseException=None) -> None:
         super(StreamClosedError, self).__init__('Stream is closed')
         self.real_error = real_error
 
@@ -115,21 +123,22 @@ class _StreamBuffer(object):
     of data are encountered.
     """
 
-    def __init__(self):
+    def __init__(self) -> None:
         # A sequence of (False, bytearray) and (True, memoryview) objects
-        self._buffers = collections.deque()
+        self._buffers = collections.deque() \
+            # type: Deque[Tuple[bool, Union[bytearray, memoryview]]]
         # Position in the first buffer
         self._first_pos = 0
         self._size = 0
 
-    def __len__(self):
+    def __len__(self) -> int:
         return self._size
 
     # Data above this size will be appended separately instead
     # of extending an existing bytearray
     _large_buf_threshold = 2048
 
-    def append(self, data):
+    def append(self, data: Union[bytes, bytearray, memoryview]) -> None:
         """
         Append the given piece of data (should be a buffer-compatible object).
         """
@@ -147,11 +156,11 @@ class _StreamBuffer(object):
             if new_buf:
                 self._buffers.append((False, bytearray(data)))
             else:
-                b += data
+                b += data  # type: ignore
 
         self._size += size
 
-    def peek(self, size):
+    def peek(self, size: int) -> memoryview:
         """
         Get a view over at most ``size`` bytes (possibly fewer) at the
         current buffer position.
@@ -164,11 +173,11 @@ class _StreamBuffer(object):
 
         pos = self._first_pos
         if is_memview:
-            return b[pos:pos + size]
+            return typing.cast(memoryview, b[pos:pos + size])
         else:
             return memoryview(b)[pos:pos + size]
 
-    def advance(self, size):
+    def advance(self, size: int) -> None:
         """
         Advance the current buffer position by ``size`` bytes.
         """
@@ -191,7 +200,7 @@ class _StreamBuffer(object):
                 # Amortized O(1) shrink for Python 2
                 pos += size
                 if len(b) <= 2 * pos:
-                    del b[:pos]
+                    del typing.cast(bytearray, b)[:pos]
                     pos = 0
                 size = 0
 
@@ -216,8 +225,8 @@ class BaseIOStream(object):
     `read_from_fd`, and optionally `get_fd_error`.
 
     """
-    def __init__(self, max_buffer_size=None,
-                 read_chunk_size=None, max_write_buffer_size=None):
+    def __init__(self, max_buffer_size: int=None,
+                 read_chunk_size: int=None, max_write_buffer_size: int=None) -> None:
         """`BaseIOStream` constructor.
 
         :arg max_buffer_size: Maximum amount of incoming data to buffer;
@@ -241,39 +250,39 @@ class BaseIOStream(object):
         self.read_chunk_size = min(read_chunk_size or 65536,
                                    self.max_buffer_size // 2)
         self.max_write_buffer_size = max_write_buffer_size
-        self.error = None
+        self.error = None  # type: Optional[BaseException]
         self._read_buffer = bytearray()
         self._read_buffer_pos = 0
         self._read_buffer_size = 0
         self._user_read_buffer = False
-        self._after_user_read_buffer = None
+        self._after_user_read_buffer = None  # type: Optional[bytearray]
         self._write_buffer = _StreamBuffer()
         self._total_write_index = 0
         self._total_write_done_index = 0
-        self._read_delimiter = None
-        self._read_regex = None
-        self._read_max_bytes = None
-        self._read_bytes = None
+        self._read_delimiter = None  # type: Optional[bytes]
+        self._read_regex = None  # type: Optional[Pattern]
+        self._read_max_bytes = None  # type: Optional[int]
+        self._read_bytes = None  # type: Optional[int]
         self._read_partial = False
         self._read_until_close = False
-        self._read_future = None
-        self._write_futures = collections.deque()
-        self._close_callback = None
-        self._connect_future = None
+        self._read_future = None  # type: Optional[Future]
+        self._write_futures = collections.deque()  # type: Deque[Tuple[int, Future[None]]]
+        self._close_callback = None  # type: Optional[Callable[[], None]]
+        self._connect_future = None  # type: Optional[Future[IOStream]]
         # _ssl_connect_future should be defined in SSLIOStream
         # but it's here so we can clean it up in _signal_closed
         # TODO: refactor that so subclasses can add additional futures
         # to be cancelled.
-        self._ssl_connect_future = None
+        self._ssl_connect_future = None  # type: Optional[Future[SSLIOStream]]
         self._connecting = False
-        self._state = None
+        self._state = None  # type: Optional[int]
         self._closed = False
 
-    def fileno(self):
+    def fileno(self) -> Union[int, ioloop._Selectable]:
         """Returns the file descriptor for this stream."""
         raise NotImplementedError()
 
-    def close_fd(self):
+    def close_fd(self) -> None:
         """Closes the file underlying this stream.
 
         ``close_fd`` is called by `BaseIOStream` and should not be called
@@ -281,14 +290,14 @@ class BaseIOStream(object):
         """
         raise NotImplementedError()
 
-    def write_to_fd(self, data):
+    def write_to_fd(self, data: memoryview) -> int:
         """Attempts to write ``data`` to the underlying file.
 
         Returns the number of bytes written.
         """
         raise NotImplementedError()
 
-    def read_from_fd(self, buf):
+    def read_from_fd(self, buf: Union[bytearray, memoryview]) -> Optional[int]:
         """Attempts to read from the underlying file.
 
         Reads up to ``len(buf)`` bytes, storing them in the buffer.
@@ -303,7 +312,7 @@ class BaseIOStream(object):
         """
         raise NotImplementedError()
 
-    def get_fd_error(self):
+    def get_fd_error(self) -> Optional[Exception]:
         """Returns information about any error on the underlying file.
 
         This method is called after the `.IOLoop` has signaled an error on the
@@ -313,7 +322,7 @@ class BaseIOStream(object):
         """
         return None
 
-    def read_until_regex(self, regex, max_bytes=None):
+    def read_until_regex(self, regex: bytes, max_bytes: int=None) -> Awaitable[bytes]:
         """Asynchronously read until we have matched the given regex.
 
         The result includes the data that matches the regex and anything
@@ -350,7 +359,7 @@ class BaseIOStream(object):
             raise
         return future
 
-    def read_until(self, delimiter, max_bytes=None):
+    def read_until(self, delimiter: bytes, max_bytes: int=None) -> Awaitable[bytes]:
         """Asynchronously read until we have found the given delimiter.
 
         The result includes all the data read including the delimiter.
@@ -383,7 +392,7 @@ class BaseIOStream(object):
             raise
         return future
 
-    def read_bytes(self, num_bytes, partial=False):
+    def read_bytes(self, num_bytes: int, partial: bool=False) -> Awaitable[bytes]:
         """Asynchronously read a number of bytes.
 
         If ``partial`` is true, data is returned as soon as we have
@@ -411,7 +420,7 @@ class BaseIOStream(object):
             raise
         return future
 
-    def read_into(self, buf, partial=False):
+    def read_into(self, buf: bytearray, partial: bool=False) -> Awaitable[int]:
         """Asynchronously read a number of bytes.
 
         ``buf`` must be a writable buffer into which data will be read.
@@ -458,7 +467,7 @@ class BaseIOStream(object):
             raise
         return future
 
-    def read_until_close(self):
+    def read_until_close(self) -> Awaitable[bytes]:
         """Asynchronously reads all data from the socket until it is closed.
 
         This will buffer all available data until ``max_buffer_size``
@@ -488,7 +497,7 @@ class BaseIOStream(object):
             raise
         return future
 
-    def write(self, data):
+    def write(self, data: Union[bytes, memoryview]) -> Awaitable[None]:
         """Asynchronously write the given data to this stream.
 
         This method returns a `.Future` that resolves (with a result
@@ -515,7 +524,7 @@ class BaseIOStream(object):
                 raise StreamBufferFullError("Reached maximum write buffer size")
             self._write_buffer.append(data)
             self._total_write_index += len(data)
-        future = Future()
+        future = Future()  # type: Future[None]
         future.add_done_callback(lambda f: f.exception())
         self._write_futures.append((self._total_write_index, future))
         if not self._connecting:
@@ -525,7 +534,7 @@ class BaseIOStream(object):
             self._maybe_add_error_listener()
         return future
 
-    def set_close_callback(self, callback):
+    def set_close_callback(self, callback: Optional[Callable[[], None]]) -> None:
         """Call the given callback when the stream is closed.
 
         This mostly is not necessary for applications that use the
@@ -540,7 +549,10 @@ class BaseIOStream(object):
         self._close_callback = callback
         self._maybe_add_error_listener()
 
-    def close(self, exc_info=False):
+    def close(self, exc_info: Union[None, bool, BaseException,
+                                    Tuple[Optional[Type[BaseException]],
+                                          Optional[BaseException],
+                                          Optional[TracebackType]]]=False) -> None:
         """Close this stream.
 
         If ``exc_info`` is true, set the ``error`` attribute to the current
@@ -567,8 +579,8 @@ class BaseIOStream(object):
             self._closed = True
         self._signal_closed()
 
-    def _signal_closed(self):
-        futures = []
+    def _signal_closed(self) -> None:
+        futures = []  # type: List[Future]
         if self._read_future is not None:
             futures.append(self._read_future)
             self._read_future = None
@@ -583,7 +595,10 @@ class BaseIOStream(object):
         if self._ssl_connect_future is not None:
             # _ssl_connect_future expects to see the real exception (typically
             # an ssl.SSLError), not just StreamClosedError.
-            self._ssl_connect_future.set_exception(self.error)
+            if self.error is not None:
+                self._ssl_connect_future.set_exception(self.error)
+            else:
+                self._ssl_connect_future.set_exception(StreamClosedError())
             self._ssl_connect_future.exception()
             self._ssl_connect_future = None
         if self._close_callback is not None:
@@ -593,21 +608,21 @@ class BaseIOStream(object):
         # Clear the buffers so they can be cleared immediately even
         # if the IOStream object is kept alive by a reference cycle.
         # TODO: Clear the read buffer too; it currently breaks some tests.
-        self._write_buffer = None
+        self._write_buffer = None  # type: ignore
 
-    def reading(self):
+    def reading(self) -> bool:
         """Returns true if we are currently reading from the stream."""
         return self._read_future is not None
 
-    def writing(self):
+    def writing(self) -> bool:
         """Returns true if we are currently writing to the stream."""
         return bool(self._write_buffer)
 
-    def closed(self):
+    def closed(self) -> bool:
         """Returns true if the stream has been closed."""
         return self._closed
 
-    def set_nodelay(self, value):
+    def set_nodelay(self, value: bool) -> None:
         """Sets the no-delay flag for this stream.
 
         By default, data written to TCP streams may be held for a time
@@ -622,7 +637,10 @@ class BaseIOStream(object):
         """
         pass
 
-    def _handle_events(self, fd, events):
+    def _handle_connect(self) -> None:
+        raise NotImplementedError()
+
+    def _handle_events(self, fd: Union[int, ioloop._Selectable], events: int) -> None:
         if self.closed():
             gen_log.warning("Got events for closed stream %s", fd)
             return
@@ -675,10 +693,10 @@ class BaseIOStream(object):
             self.close(exc_info=e)
             raise
 
-    def _read_to_buffer_loop(self):
+    def _read_to_buffer_loop(self) -> Optional[int]:
         # This method is called from _handle_read and _try_inline_read.
         if self._read_bytes is not None:
-            target_bytes = self._read_bytes
+            target_bytes = self._read_bytes  # type: Optional[int]
         elif self._read_max_bytes is not None:
             target_bytes = self._read_max_bytes
         elif self.reading():
@@ -717,7 +735,7 @@ class BaseIOStream(object):
                 next_find_pos = self._read_buffer_size * 2
         return self._find_read_pos()
 
-    def _handle_read(self):
+    def _handle_read(self) -> None:
         try:
             pos = self._read_to_buffer_loop()
         except UnsatisfiableReadError:
@@ -729,19 +747,19 @@ class BaseIOStream(object):
         if pos is not None:
             self._read_from_buffer(pos)
 
-    def _start_read(self):
+    def _start_read(self) -> Future:
         assert self._read_future is None, "Already reading"
         self._read_future = Future()
         return self._read_future
 
-    def _finish_read(self, size, streaming):
+    def _finish_read(self, size: int, streaming: bool) -> None:
         if self._user_read_buffer:
             self._read_buffer = self._after_user_read_buffer or bytearray()
             self._after_user_read_buffer = None
             self._read_buffer_pos = 0
             self._read_buffer_size = len(self._read_buffer)
             self._user_read_buffer = False
-            result = size
+            result = size  # type: Union[int, bytes]
         else:
             result = self._consume(size)
         if self._read_future is not None:
@@ -750,7 +768,7 @@ class BaseIOStream(object):
             future.set_result(result)
         self._maybe_add_error_listener()
 
-    def _try_inline_read(self):
+    def _try_inline_read(self) -> None:
         """Attempt to complete the current read operation from buffered data.
 
         If the read can be completed without blocking, schedules the
@@ -772,7 +790,7 @@ class BaseIOStream(object):
         if not self.closed():
             self._add_io_state(ioloop.IOLoop.READ)
 
-    def _read_to_buffer(self):
+    def _read_to_buffer(self) -> Optional[int]:
         """Reads from the socket and appends the result to the read buffer.
 
         Returns the number of bytes read.  Returns 0 if there is nothing
@@ -783,7 +801,8 @@ class BaseIOStream(object):
             while True:
                 try:
                     if self._user_read_buffer:
-                        buf = memoryview(self._read_buffer)[self._read_buffer_size:]
+                        buf = memoryview(self._read_buffer)[self._read_buffer_size:] \
+                            # type: Union[memoryview, bytearray]
                     else:
                         buf = bytearray(self.read_chunk_size)
                     bytes_read = self.read_from_fd(buf)
@@ -796,7 +815,7 @@ class BaseIOStream(object):
                         # an error to minimize log spam  (the exception will
                         # be available on self.error for apps that care).
                         self.close(exc_info=e)
-                        return
+                        return None
                     self.close(exc_info=e)
                     raise
                 break
@@ -811,14 +830,14 @@ class BaseIOStream(object):
         finally:
             # Break the reference to buf so we don't waste a chunk's worth of
             # memory in case an exception hangs on to our stack frame.
-            buf = None
+            del buf
         if self._read_buffer_size > self.max_buffer_size:
             gen_log.error("Reached maximum read buffer size")
             self.close()
             raise StreamBufferFullError("Reached maximum read buffer size")
         return bytes_read
 
-    def _read_from_buffer(self, pos):
+    def _read_from_buffer(self, pos: int) -> None:
         """Attempts to complete the currently-pending read from the buffer.
 
         The argument is either a position in the read buffer or None,
@@ -828,7 +847,7 @@ class BaseIOStream(object):
         self._read_partial = False
         self._finish_read(pos, False)
 
-    def _find_read_pos(self):
+    def _find_read_pos(self) -> Optional[int]:
         """Attempts to find a position in the read buffer that satisfies
         the currently-pending read.
 
@@ -871,14 +890,14 @@ class BaseIOStream(object):
                 self._check_max_bytes(self._read_regex, self._read_buffer_size)
         return None
 
-    def _check_max_bytes(self, delimiter, size):
+    def _check_max_bytes(self, delimiter: Union[bytes, Pattern], size: int) -> None:
         if (self._read_max_bytes is not None and
                 size > self._read_max_bytes):
             raise UnsatisfiableReadError(
                 "delimiter %r not found within %d bytes" % (
                     delimiter, self._read_max_bytes))
 
-    def _handle_write(self):
+    def _handle_write(self) -> None:
         while True:
             size = len(self._write_buffer)
             if not size:
@@ -918,7 +937,7 @@ class BaseIOStream(object):
             self._write_futures.popleft()
             future.set_result(None)
 
-    def _consume(self, loc):
+    def _consume(self, loc: int) -> bytes:
         # Consume loc bytes from the read buffer and return them
         if loc == 0:
             return b""
@@ -937,11 +956,11 @@ class BaseIOStream(object):
             self._read_buffer_pos = 0
         return b
 
-    def _check_closed(self):
+    def _check_closed(self) -> None:
         if self.closed():
             raise StreamClosedError(real_error=self.error)
 
-    def _maybe_add_error_listener(self):
+    def _maybe_add_error_listener(self) -> None:
         # This method is part of an optimization: to detect a connection that
         # is closed when we're not actively reading or writing, we must listen
         # for read events.  However, it is inefficient to do this when the
@@ -954,7 +973,7 @@ class BaseIOStream(object):
                     self._close_callback is not None):
                 self._add_io_state(ioloop.IOLoop.READ)
 
-    def _add_io_state(self, state):
+    def _add_io_state(self, state: int) -> None:
         """Adds `state` (IOLoop.{READ,WRITE} flags) to our event handler.
 
         Implementation notes: Reads and writes have a fast path and a
@@ -984,7 +1003,7 @@ class BaseIOStream(object):
             self._state = self._state | state
             self.io_loop.update_handler(self.fileno(), self._state)
 
-    def _is_connreset(self, exc):
+    def _is_connreset(self, exc: BaseException) -> bool:
         """Return true if exc is ECONNRESET or equivalent.
 
         May be overridden in subclasses.
@@ -1040,43 +1059,44 @@ class IOStream(BaseIOStream):
        :hide:
 
     """
-    def __init__(self, socket, *args, **kwargs):
+    def __init__(self, socket: socket.socket, *args: Any, **kwargs: Any) -> None:
         self.socket = socket
         self.socket.setblocking(False)
         super(IOStream, self).__init__(*args, **kwargs)
 
-    def fileno(self):
+    def fileno(self) -> Union[int, ioloop._Selectable]:
         return self.socket
 
-    def close_fd(self):
+    def close_fd(self) -> None:
         self.socket.close()
-        self.socket = None
+        self.socket = None  # type: ignore
 
-    def get_fd_error(self):
+    def get_fd_error(self) -> Optional[Exception]:
         errno = self.socket.getsockopt(socket.SOL_SOCKET,
                                        socket.SO_ERROR)
         return socket.error(errno, os.strerror(errno))
 
-    def read_from_fd(self, buf):
+    def read_from_fd(self, buf: Union[bytearray, memoryview]) -> Optional[int]:
         try:
-            return self.socket.recv_into(buf)
+            return self.socket.recv_into(buf, len(buf))
         except socket.error as e:
             if e.args[0] in _ERRNO_WOULDBLOCK:
                 return None
             else:
                 raise
         finally:
-            buf = None
+            del buf
 
-    def write_to_fd(self, data):
+    def write_to_fd(self, data: memoryview) -> int:
         try:
-            return self.socket.send(data)
+            return self.socket.send(data)  # type: ignore
         finally:
             # Avoid keeping to data, which can be a memoryview.
             # See https://github.com/tornadoweb/tornado/pull/2008
             del data
 
-    def connect(self, address, server_hostname=None):
+    def connect(self: _IOStreamType, address: tuple,
+                server_hostname: str=None) -> 'Future[_IOStreamType]':
         """Connects the socket to a remote address without blocking.
 
         May only be called if the socket passed to the constructor was
@@ -1122,7 +1142,8 @@ class IOStream(BaseIOStream):
 
         """
         self._connecting = True
-        future = self._connect_future = Future()
+        future = Future()  # type: Future[_IOStreamType]
+        self._connect_future = typing.cast('Future[IOStream]', future)
         try:
             self.socket.connect(address)
         except socket.error as e:
@@ -1143,7 +1164,9 @@ class IOStream(BaseIOStream):
         self._add_io_state(self.io_loop.WRITE)
         return future
 
-    def start_tls(self, server_side, ssl_options=None, server_hostname=None):
+    def start_tls(self, server_side: bool,
+                  ssl_options: Union[Dict[str, Any], ssl.SSLContext]=None,
+                  server_hostname: str=None) -> Awaitable['SSLIOStream']:
         """Convert this `IOStream` to an `SSLIOStream`.
 
         This enables protocols that begin in clear-text mode and
@@ -1192,7 +1215,7 @@ class IOStream(BaseIOStream):
 
         socket = self.socket
         self.io_loop.remove_handler(socket)
-        self.socket = None
+        self.socket = None  # type: ignore
         socket = ssl_wrap_socket(socket, ssl_options,
                                  server_hostname=server_hostname,
                                  server_side=server_side,
@@ -1200,7 +1223,7 @@ class IOStream(BaseIOStream):
         orig_close_callback = self._close_callback
         self._close_callback = None
 
-        future = Future()
+        future = Future()  # type: Future[SSLIOStream]
         ssl_stream = SSLIOStream(socket, ssl_options=ssl_options)
         ssl_stream.set_close_callback(orig_close_callback)
         ssl_stream._ssl_connect_future = future
@@ -1208,7 +1231,7 @@ class IOStream(BaseIOStream):
         ssl_stream.read_chunk_size = self.read_chunk_size
         return future
 
-    def _handle_connect(self):
+    def _handle_connect(self) -> None:
         try:
             err = self.socket.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR)
         except socket.error as e:
@@ -1233,7 +1256,7 @@ class IOStream(BaseIOStream):
             future.set_result(self)
         self._connecting = False
 
-    def set_nodelay(self, value):
+    def set_nodelay(self, value: bool) -> None:
         if (self.socket is not None and
                 self.socket.family in (socket.AF_INET, socket.AF_INET6)):
             try:
@@ -1258,7 +1281,9 @@ class SSLIOStream(IOStream):
     before constructing the `SSLIOStream`.  Unconnected sockets will be
     wrapped when `IOStream.connect` is finished.
     """
-    def __init__(self, *args, **kwargs):
+    socket = None  # type: ssl.SSLSocket
+
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
         """The ``ssl_options`` keyword argument may either be an
         `ssl.SSLContext` object or a dictionary of keywords arguments
         for `ssl.wrap_socket`
@@ -1268,7 +1293,7 @@ class SSLIOStream(IOStream):
         self._ssl_accepting = True
         self._handshake_reading = False
         self._handshake_writing = False
-        self._server_hostname = None
+        self._server_hostname = None  # type: Optional[str]
 
         # If the socket is already connected, attempt to start the handshake.
         try:
@@ -1281,13 +1306,13 @@ class SSLIOStream(IOStream):
             # _handle_events.
             self._add_io_state(self.io_loop.WRITE)
 
-    def reading(self):
+    def reading(self) -> bool:
         return self._handshake_reading or super(SSLIOStream, self).reading()
 
-    def writing(self):
+    def writing(self) -> bool:
         return self._handshake_writing or super(SSLIOStream, self).writing()
 
-    def _do_ssl_handshake(self):
+    def _do_ssl_handshake(self) -> None:
         # Based on code from test_ssl.py in the python stdlib
         try:
             self._handshake_reading = False
@@ -1333,13 +1358,13 @@ class SSLIOStream(IOStream):
                 return
             self._finish_ssl_connect()
 
-    def _finish_ssl_connect(self):
+    def _finish_ssl_connect(self) -> None:
         if self._ssl_connect_future is not None:
             future = self._ssl_connect_future
             self._ssl_connect_future = None
             future.set_result(self)
 
-    def _verify_cert(self, peercert):
+    def _verify_cert(self, peercert: Any) -> bool:
         """Returns True if peercert is valid according to the configured
         validation mode and hostname.
 
@@ -1366,19 +1391,19 @@ class SSLIOStream(IOStream):
         else:
             return True
 
-    def _handle_read(self):
+    def _handle_read(self) -> None:
         if self._ssl_accepting:
             self._do_ssl_handshake()
             return
         super(SSLIOStream, self)._handle_read()
 
-    def _handle_write(self):
+    def _handle_write(self) -> None:
         if self._ssl_accepting:
             self._do_ssl_handshake()
             return
         super(SSLIOStream, self)._handle_write()
 
-    def connect(self, address, server_hostname=None):
+    def connect(self, address: Tuple, server_hostname: str=None) -> 'Future[SSLIOStream]':
         self._server_hostname = server_hostname
         # Ignore the result of connect(). If it fails,
         # wait_for_handshake will raise an error too. This is
@@ -1395,7 +1420,7 @@ class SSLIOStream(IOStream):
         fut.add_done_callback(lambda f: f.exception())
         return self.wait_for_handshake()
 
-    def _handle_connect(self):
+    def _handle_connect(self) -> None:
         # Call the superclass method to check for errors.
         super(SSLIOStream, self)._handle_connect()
         if self.closed():
@@ -1412,13 +1437,14 @@ class SSLIOStream(IOStream):
         # wrap_socket().
         self.io_loop.remove_handler(self.socket)
         old_state = self._state
+        assert old_state is not None
         self._state = None
         self.socket = ssl_wrap_socket(self.socket, self._ssl_options,
                                       server_hostname=self._server_hostname,
                                       do_handshake_on_connect=False)
         self._add_io_state(old_state)
 
-    def wait_for_handshake(self):
+    def wait_for_handshake(self) -> 'Future[SSLIOStream]':
         """Wait for the initial SSL handshake to complete.
 
         If a ``callback`` is given, it will be called with no
@@ -1450,9 +1476,9 @@ class SSLIOStream(IOStream):
             self._finish_ssl_connect()
         return future
 
-    def write_to_fd(self, data):
+    def write_to_fd(self, data: memoryview) -> int:
         try:
-            return self.socket.send(data)
+            return self.socket.send(data)  # type: ignore
         except ssl.SSLError as e:
             if e.args[0] == ssl.SSL_ERROR_WANT_WRITE:
                 # In Python 3.5+, SSLSocket.send raises a WANT_WRITE error if
@@ -1468,7 +1494,7 @@ class SSLIOStream(IOStream):
             # See https://github.com/tornadoweb/tornado/pull/2008
             del data
 
-    def read_from_fd(self, buf):
+    def read_from_fd(self, buf: Union[bytearray, memoryview]) -> Optional[int]:
         try:
             if self._ssl_accepting:
                 # If the handshake hasn't finished yet, there can't be anything
@@ -1476,7 +1502,7 @@ class SSLIOStream(IOStream):
                 # depending on the SSL version)
                 return None
             try:
-                return self.socket.recv_into(buf)
+                return self.socket.recv_into(buf, len(buf))
             except ssl.SSLError as e:
                 # SSLError is a subclass of socket.error, so this except
                 # block must come first.
@@ -1490,9 +1516,9 @@ class SSLIOStream(IOStream):
                 else:
                     raise
         finally:
-            buf = None
+            del buf
 
-    def _is_connreset(self, e):
+    def _is_connreset(self, e: BaseException) -> bool:
         if isinstance(e, ssl.SSLError) and e.args[0] == ssl.SSL_ERROR_EOF:
             return True
         return super(SSLIOStream, self)._is_connreset(e)
@@ -1506,29 +1532,29 @@ class PipeIOStream(BaseIOStream):
     one-way, so a `PipeIOStream` can be used for reading or writing but not
     both.
     """
-    def __init__(self, fd, *args, **kwargs):
+    def __init__(self, fd: int, *args: Any, **kwargs: Any) -> None:
         self.fd = fd
         self._fio = io.FileIO(self.fd, "r+")
         _set_nonblocking(fd)
         super(PipeIOStream, self).__init__(*args, **kwargs)
 
-    def fileno(self):
+    def fileno(self) -> int:
         return self.fd
 
-    def close_fd(self):
+    def close_fd(self) -> None:
         self._fio.close()
 
-    def write_to_fd(self, data):
+    def write_to_fd(self, data: memoryview) -> int:
         try:
-            return os.write(self.fd, data)
+            return os.write(self.fd, data)  # type: ignore
         finally:
             # Avoid keeping to data, which can be a memoryview.
             # See https://github.com/tornadoweb/tornado/pull/2008
             del data
 
-    def read_from_fd(self, buf):
+    def read_from_fd(self, buf: Union[bytearray, memoryview]) -> Optional[int]:
         try:
-            return self._fio.readinto(buf)
+            return self._fio.readinto(buf)  # type: ignore
         except (IOError, OSError) as e:
             if errno_from_exception(e) == errno.EBADF:
                 # If the writing half of a pipe is closed, select will
@@ -1538,9 +1564,9 @@ class PipeIOStream(BaseIOStream):
             else:
                 raise
         finally:
-            buf = None
+            del buf
 
 
-def doctests():
+def doctests() -> Any:
     import doctest
     return doctest.DocTestSuite()
index 9c6dbf76e6c346de3ea18e6deac4e03755a51f25..63d84ab1e33c5d6f137d32ce378486204d47ab52 100644 (file)
@@ -649,7 +649,7 @@ class TestIOStreamMixin(TestReadWriteMixin):
     @gen.coroutine
     def make_iostream_pair(self, **kwargs):
         listener, port = bind_unused_port()
-        server_stream_fut = Future()
+        server_stream_fut = Future()  # type: Future[IOStream]
 
         def accept_callback(connection, address):
             server_stream_fut.set_result(self._make_server_iostream(connection, **kwargs))
@@ -679,11 +679,11 @@ class TestIOStreamMixin(TestReadWriteMixin):
 
         self.assertTrue(isinstance(stream.error, socket.error), stream.error)
         if sys.platform != 'cygwin':
-            _ERRNO_CONNREFUSED = (errno.ECONNREFUSED,)
+            _ERRNO_CONNREFUSED = [errno.ECONNREFUSED]
             if hasattr(errno, "WSAECONNREFUSED"):
-                _ERRNO_CONNREFUSED += (errno.WSAECONNREFUSED,)
+                _ERRNO_CONNREFUSED.append(errno.WSAECONNREFUSED)  # type: ignore
             # cygwin's errnos don't match those used on native windows python
-            self.assertTrue(stream.error.args[0] in _ERRNO_CONNREFUSED)
+            self.assertTrue(stream.error.args[0] in _ERRNO_CONNREFUSED)  # type: ignore
 
     @gen_test
     def test_gaierror(self):
@@ -849,7 +849,7 @@ class TestIOStreamStartTLS(AsyncTestCase):
             super(TestIOStreamStartTLS, self).setUp()
             self.listener, self.port = bind_unused_port()
             self.server_stream = None
-            self.server_accepted = Future()
+            self.server_accepted = Future()  # type: Future[None]
             netutil.add_accept_handler(self.listener, self.accept)
             self.client_stream = IOStream(socket.socket())
             self.io_loop.add_future(self.client_stream.connect(
@@ -969,7 +969,7 @@ class WaitForHandshakeTest(AsyncTestCase):
     @gen_test
     def test_wait_for_handshake_future(self):
         test = self
-        handshake_future = Future()
+        handshake_future = Future()  # type: Future[None]
 
         class TestServer(TCPServer):
             def handle_stream(self, stream, address):
@@ -987,7 +987,7 @@ class WaitForHandshakeTest(AsyncTestCase):
     @gen_test
     def test_wait_for_handshake_already_waiting_error(self):
         test = self
-        handshake_future = Future()
+        handshake_future = Future()  # type: Future[None]
 
         class TestServer(TCPServer):
             @gen.coroutine
@@ -1003,7 +1003,7 @@ class WaitForHandshakeTest(AsyncTestCase):
 
     @gen_test
     def test_wait_for_handshake_already_connected(self):
-        handshake_future = Future()
+        handshake_future = Future()  # type: Future[None]
 
         class TestServer(TCPServer):
             @gen.coroutine