From: Ben Darnell Date: Sun, 12 Aug 2018 16:26:00 +0000 (-0400) Subject: iostream: Add type annotations X-Git-Tag: v6.0.0b1~33^2~3 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=0c7402842b0182efcd2dc8f95167bd46a3aaa6e0;p=thirdparty%2Ftornado.git iostream: Add type annotations --- diff --git a/setup.cfg b/setup.cfg index cde5d24f2..c843ce7b3 100644 --- a/setup.cfg +++ b/setup.cfg @@ -22,6 +22,9 @@ disallow_untyped_defs = True [mypy-tornado.ioloop] disallow_untyped_defs = True +[mypy-tornado.iostream] +disallow_untyped_defs = True + [mypy-tornado.locale] disallow_untyped_defs = True @@ -63,6 +66,9 @@ check_untyped_defs = True [mypy-tornado.test.ioloop_test] check_untyped_defs = True +[mypy-tornado.test.iostream_test] +check_untyped_defs = True + [mypy-tornado.test.locale_test] check_untyped_defs = True diff --git a/tornado/iostream.py b/tornado/iostream.py index 514c36f50..fe41cebaf 100644 --- a/tornado/iostream.py +++ b/tornado/iostream.py @@ -39,6 +39,14 @@ from tornado.log import gen_log from tornado.netutil import ssl_wrap_socket, _client_ssl_defaults, _server_ssl_defaults from tornado.util import errno_from_exception +import typing +from typing import Union, Optional, Awaitable, Callable, Type, Pattern, Any, Dict, TypeVar, Tuple +from types import TracebackType +if typing.TYPE_CHECKING: + from typing import Deque, List # noqa: F401 + +_IOStreamType = TypeVar('_IOStreamType', bound='IOStream') + try: from tornado.platform.posix import _set_nonblocking except ImportError: @@ -90,7 +98,7 @@ class StreamClosedError(IOError): .. versionchanged:: 4.3 Added the ``real_error`` attribute. """ - def __init__(self, real_error=None): + def __init__(self, real_error: BaseException=None) -> None: super(StreamClosedError, self).__init__('Stream is closed') self.real_error = real_error @@ -115,21 +123,22 @@ class _StreamBuffer(object): of data are encountered. """ - def __init__(self): + def __init__(self) -> None: # A sequence of (False, bytearray) and (True, memoryview) objects - self._buffers = collections.deque() + self._buffers = collections.deque() \ + # type: Deque[Tuple[bool, Union[bytearray, memoryview]]] # Position in the first buffer self._first_pos = 0 self._size = 0 - def __len__(self): + def __len__(self) -> int: return self._size # Data above this size will be appended separately instead # of extending an existing bytearray _large_buf_threshold = 2048 - def append(self, data): + def append(self, data: Union[bytes, bytearray, memoryview]) -> None: """ Append the given piece of data (should be a buffer-compatible object). """ @@ -147,11 +156,11 @@ class _StreamBuffer(object): if new_buf: self._buffers.append((False, bytearray(data))) else: - b += data + b += data # type: ignore self._size += size - def peek(self, size): + def peek(self, size: int) -> memoryview: """ Get a view over at most ``size`` bytes (possibly fewer) at the current buffer position. @@ -164,11 +173,11 @@ class _StreamBuffer(object): pos = self._first_pos if is_memview: - return b[pos:pos + size] + return typing.cast(memoryview, b[pos:pos + size]) else: return memoryview(b)[pos:pos + size] - def advance(self, size): + def advance(self, size: int) -> None: """ Advance the current buffer position by ``size`` bytes. """ @@ -191,7 +200,7 @@ class _StreamBuffer(object): # Amortized O(1) shrink for Python 2 pos += size if len(b) <= 2 * pos: - del b[:pos] + del typing.cast(bytearray, b)[:pos] pos = 0 size = 0 @@ -216,8 +225,8 @@ class BaseIOStream(object): `read_from_fd`, and optionally `get_fd_error`. """ - def __init__(self, max_buffer_size=None, - read_chunk_size=None, max_write_buffer_size=None): + def __init__(self, max_buffer_size: int=None, + read_chunk_size: int=None, max_write_buffer_size: int=None) -> None: """`BaseIOStream` constructor. :arg max_buffer_size: Maximum amount of incoming data to buffer; @@ -241,39 +250,39 @@ class BaseIOStream(object): self.read_chunk_size = min(read_chunk_size or 65536, self.max_buffer_size // 2) self.max_write_buffer_size = max_write_buffer_size - self.error = None + self.error = None # type: Optional[BaseException] self._read_buffer = bytearray() self._read_buffer_pos = 0 self._read_buffer_size = 0 self._user_read_buffer = False - self._after_user_read_buffer = None + self._after_user_read_buffer = None # type: Optional[bytearray] self._write_buffer = _StreamBuffer() self._total_write_index = 0 self._total_write_done_index = 0 - self._read_delimiter = None - self._read_regex = None - self._read_max_bytes = None - self._read_bytes = None + self._read_delimiter = None # type: Optional[bytes] + self._read_regex = None # type: Optional[Pattern] + self._read_max_bytes = None # type: Optional[int] + self._read_bytes = None # type: Optional[int] self._read_partial = False self._read_until_close = False - self._read_future = None - self._write_futures = collections.deque() - self._close_callback = None - self._connect_future = None + self._read_future = None # type: Optional[Future] + self._write_futures = collections.deque() # type: Deque[Tuple[int, Future[None]]] + self._close_callback = None # type: Optional[Callable[[], None]] + self._connect_future = None # type: Optional[Future[IOStream]] # _ssl_connect_future should be defined in SSLIOStream # but it's here so we can clean it up in _signal_closed # TODO: refactor that so subclasses can add additional futures # to be cancelled. - self._ssl_connect_future = None + self._ssl_connect_future = None # type: Optional[Future[SSLIOStream]] self._connecting = False - self._state = None + self._state = None # type: Optional[int] self._closed = False - def fileno(self): + def fileno(self) -> Union[int, ioloop._Selectable]: """Returns the file descriptor for this stream.""" raise NotImplementedError() - def close_fd(self): + def close_fd(self) -> None: """Closes the file underlying this stream. ``close_fd`` is called by `BaseIOStream` and should not be called @@ -281,14 +290,14 @@ class BaseIOStream(object): """ raise NotImplementedError() - def write_to_fd(self, data): + def write_to_fd(self, data: memoryview) -> int: """Attempts to write ``data`` to the underlying file. Returns the number of bytes written. """ raise NotImplementedError() - def read_from_fd(self, buf): + def read_from_fd(self, buf: Union[bytearray, memoryview]) -> Optional[int]: """Attempts to read from the underlying file. Reads up to ``len(buf)`` bytes, storing them in the buffer. @@ -303,7 +312,7 @@ class BaseIOStream(object): """ raise NotImplementedError() - def get_fd_error(self): + def get_fd_error(self) -> Optional[Exception]: """Returns information about any error on the underlying file. This method is called after the `.IOLoop` has signaled an error on the @@ -313,7 +322,7 @@ class BaseIOStream(object): """ return None - def read_until_regex(self, regex, max_bytes=None): + def read_until_regex(self, regex: bytes, max_bytes: int=None) -> Awaitable[bytes]: """Asynchronously read until we have matched the given regex. The result includes the data that matches the regex and anything @@ -350,7 +359,7 @@ class BaseIOStream(object): raise return future - def read_until(self, delimiter, max_bytes=None): + def read_until(self, delimiter: bytes, max_bytes: int=None) -> Awaitable[bytes]: """Asynchronously read until we have found the given delimiter. The result includes all the data read including the delimiter. @@ -383,7 +392,7 @@ class BaseIOStream(object): raise return future - def read_bytes(self, num_bytes, partial=False): + def read_bytes(self, num_bytes: int, partial: bool=False) -> Awaitable[bytes]: """Asynchronously read a number of bytes. If ``partial`` is true, data is returned as soon as we have @@ -411,7 +420,7 @@ class BaseIOStream(object): raise return future - def read_into(self, buf, partial=False): + def read_into(self, buf: bytearray, partial: bool=False) -> Awaitable[int]: """Asynchronously read a number of bytes. ``buf`` must be a writable buffer into which data will be read. @@ -458,7 +467,7 @@ class BaseIOStream(object): raise return future - def read_until_close(self): + def read_until_close(self) -> Awaitable[bytes]: """Asynchronously reads all data from the socket until it is closed. This will buffer all available data until ``max_buffer_size`` @@ -488,7 +497,7 @@ class BaseIOStream(object): raise return future - def write(self, data): + def write(self, data: Union[bytes, memoryview]) -> Awaitable[None]: """Asynchronously write the given data to this stream. This method returns a `.Future` that resolves (with a result @@ -515,7 +524,7 @@ class BaseIOStream(object): raise StreamBufferFullError("Reached maximum write buffer size") self._write_buffer.append(data) self._total_write_index += len(data) - future = Future() + future = Future() # type: Future[None] future.add_done_callback(lambda f: f.exception()) self._write_futures.append((self._total_write_index, future)) if not self._connecting: @@ -525,7 +534,7 @@ class BaseIOStream(object): self._maybe_add_error_listener() return future - def set_close_callback(self, callback): + def set_close_callback(self, callback: Optional[Callable[[], None]]) -> None: """Call the given callback when the stream is closed. This mostly is not necessary for applications that use the @@ -540,7 +549,10 @@ class BaseIOStream(object): self._close_callback = callback self._maybe_add_error_listener() - def close(self, exc_info=False): + def close(self, exc_info: Union[None, bool, BaseException, + Tuple[Optional[Type[BaseException]], + Optional[BaseException], + Optional[TracebackType]]]=False) -> None: """Close this stream. If ``exc_info`` is true, set the ``error`` attribute to the current @@ -567,8 +579,8 @@ class BaseIOStream(object): self._closed = True self._signal_closed() - def _signal_closed(self): - futures = [] + def _signal_closed(self) -> None: + futures = [] # type: List[Future] if self._read_future is not None: futures.append(self._read_future) self._read_future = None @@ -583,7 +595,10 @@ class BaseIOStream(object): if self._ssl_connect_future is not None: # _ssl_connect_future expects to see the real exception (typically # an ssl.SSLError), not just StreamClosedError. - self._ssl_connect_future.set_exception(self.error) + if self.error is not None: + self._ssl_connect_future.set_exception(self.error) + else: + self._ssl_connect_future.set_exception(StreamClosedError()) self._ssl_connect_future.exception() self._ssl_connect_future = None if self._close_callback is not None: @@ -593,21 +608,21 @@ class BaseIOStream(object): # Clear the buffers so they can be cleared immediately even # if the IOStream object is kept alive by a reference cycle. # TODO: Clear the read buffer too; it currently breaks some tests. - self._write_buffer = None + self._write_buffer = None # type: ignore - def reading(self): + def reading(self) -> bool: """Returns true if we are currently reading from the stream.""" return self._read_future is not None - def writing(self): + def writing(self) -> bool: """Returns true if we are currently writing to the stream.""" return bool(self._write_buffer) - def closed(self): + def closed(self) -> bool: """Returns true if the stream has been closed.""" return self._closed - def set_nodelay(self, value): + def set_nodelay(self, value: bool) -> None: """Sets the no-delay flag for this stream. By default, data written to TCP streams may be held for a time @@ -622,7 +637,10 @@ class BaseIOStream(object): """ pass - def _handle_events(self, fd, events): + def _handle_connect(self) -> None: + raise NotImplementedError() + + def _handle_events(self, fd: Union[int, ioloop._Selectable], events: int) -> None: if self.closed(): gen_log.warning("Got events for closed stream %s", fd) return @@ -675,10 +693,10 @@ class BaseIOStream(object): self.close(exc_info=e) raise - def _read_to_buffer_loop(self): + def _read_to_buffer_loop(self) -> Optional[int]: # This method is called from _handle_read and _try_inline_read. if self._read_bytes is not None: - target_bytes = self._read_bytes + target_bytes = self._read_bytes # type: Optional[int] elif self._read_max_bytes is not None: target_bytes = self._read_max_bytes elif self.reading(): @@ -717,7 +735,7 @@ class BaseIOStream(object): next_find_pos = self._read_buffer_size * 2 return self._find_read_pos() - def _handle_read(self): + def _handle_read(self) -> None: try: pos = self._read_to_buffer_loop() except UnsatisfiableReadError: @@ -729,19 +747,19 @@ class BaseIOStream(object): if pos is not None: self._read_from_buffer(pos) - def _start_read(self): + def _start_read(self) -> Future: assert self._read_future is None, "Already reading" self._read_future = Future() return self._read_future - def _finish_read(self, size, streaming): + def _finish_read(self, size: int, streaming: bool) -> None: if self._user_read_buffer: self._read_buffer = self._after_user_read_buffer or bytearray() self._after_user_read_buffer = None self._read_buffer_pos = 0 self._read_buffer_size = len(self._read_buffer) self._user_read_buffer = False - result = size + result = size # type: Union[int, bytes] else: result = self._consume(size) if self._read_future is not None: @@ -750,7 +768,7 @@ class BaseIOStream(object): future.set_result(result) self._maybe_add_error_listener() - def _try_inline_read(self): + def _try_inline_read(self) -> None: """Attempt to complete the current read operation from buffered data. If the read can be completed without blocking, schedules the @@ -772,7 +790,7 @@ class BaseIOStream(object): if not self.closed(): self._add_io_state(ioloop.IOLoop.READ) - def _read_to_buffer(self): + def _read_to_buffer(self) -> Optional[int]: """Reads from the socket and appends the result to the read buffer. Returns the number of bytes read. Returns 0 if there is nothing @@ -783,7 +801,8 @@ class BaseIOStream(object): while True: try: if self._user_read_buffer: - buf = memoryview(self._read_buffer)[self._read_buffer_size:] + buf = memoryview(self._read_buffer)[self._read_buffer_size:] \ + # type: Union[memoryview, bytearray] else: buf = bytearray(self.read_chunk_size) bytes_read = self.read_from_fd(buf) @@ -796,7 +815,7 @@ class BaseIOStream(object): # an error to minimize log spam (the exception will # be available on self.error for apps that care). self.close(exc_info=e) - return + return None self.close(exc_info=e) raise break @@ -811,14 +830,14 @@ class BaseIOStream(object): finally: # Break the reference to buf so we don't waste a chunk's worth of # memory in case an exception hangs on to our stack frame. - buf = None + del buf if self._read_buffer_size > self.max_buffer_size: gen_log.error("Reached maximum read buffer size") self.close() raise StreamBufferFullError("Reached maximum read buffer size") return bytes_read - def _read_from_buffer(self, pos): + def _read_from_buffer(self, pos: int) -> None: """Attempts to complete the currently-pending read from the buffer. The argument is either a position in the read buffer or None, @@ -828,7 +847,7 @@ class BaseIOStream(object): self._read_partial = False self._finish_read(pos, False) - def _find_read_pos(self): + def _find_read_pos(self) -> Optional[int]: """Attempts to find a position in the read buffer that satisfies the currently-pending read. @@ -871,14 +890,14 @@ class BaseIOStream(object): self._check_max_bytes(self._read_regex, self._read_buffer_size) return None - def _check_max_bytes(self, delimiter, size): + def _check_max_bytes(self, delimiter: Union[bytes, Pattern], size: int) -> None: if (self._read_max_bytes is not None and size > self._read_max_bytes): raise UnsatisfiableReadError( "delimiter %r not found within %d bytes" % ( delimiter, self._read_max_bytes)) - def _handle_write(self): + def _handle_write(self) -> None: while True: size = len(self._write_buffer) if not size: @@ -918,7 +937,7 @@ class BaseIOStream(object): self._write_futures.popleft() future.set_result(None) - def _consume(self, loc): + def _consume(self, loc: int) -> bytes: # Consume loc bytes from the read buffer and return them if loc == 0: return b"" @@ -937,11 +956,11 @@ class BaseIOStream(object): self._read_buffer_pos = 0 return b - def _check_closed(self): + def _check_closed(self) -> None: if self.closed(): raise StreamClosedError(real_error=self.error) - def _maybe_add_error_listener(self): + def _maybe_add_error_listener(self) -> None: # This method is part of an optimization: to detect a connection that # is closed when we're not actively reading or writing, we must listen # for read events. However, it is inefficient to do this when the @@ -954,7 +973,7 @@ class BaseIOStream(object): self._close_callback is not None): self._add_io_state(ioloop.IOLoop.READ) - def _add_io_state(self, state): + def _add_io_state(self, state: int) -> None: """Adds `state` (IOLoop.{READ,WRITE} flags) to our event handler. Implementation notes: Reads and writes have a fast path and a @@ -984,7 +1003,7 @@ class BaseIOStream(object): self._state = self._state | state self.io_loop.update_handler(self.fileno(), self._state) - def _is_connreset(self, exc): + def _is_connreset(self, exc: BaseException) -> bool: """Return true if exc is ECONNRESET or equivalent. May be overridden in subclasses. @@ -1040,43 +1059,44 @@ class IOStream(BaseIOStream): :hide: """ - def __init__(self, socket, *args, **kwargs): + def __init__(self, socket: socket.socket, *args: Any, **kwargs: Any) -> None: self.socket = socket self.socket.setblocking(False) super(IOStream, self).__init__(*args, **kwargs) - def fileno(self): + def fileno(self) -> Union[int, ioloop._Selectable]: return self.socket - def close_fd(self): + def close_fd(self) -> None: self.socket.close() - self.socket = None + self.socket = None # type: ignore - def get_fd_error(self): + def get_fd_error(self) -> Optional[Exception]: errno = self.socket.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) return socket.error(errno, os.strerror(errno)) - def read_from_fd(self, buf): + def read_from_fd(self, buf: Union[bytearray, memoryview]) -> Optional[int]: try: - return self.socket.recv_into(buf) + return self.socket.recv_into(buf, len(buf)) except socket.error as e: if e.args[0] in _ERRNO_WOULDBLOCK: return None else: raise finally: - buf = None + del buf - def write_to_fd(self, data): + def write_to_fd(self, data: memoryview) -> int: try: - return self.socket.send(data) + return self.socket.send(data) # type: ignore finally: # Avoid keeping to data, which can be a memoryview. # See https://github.com/tornadoweb/tornado/pull/2008 del data - def connect(self, address, server_hostname=None): + def connect(self: _IOStreamType, address: tuple, + server_hostname: str=None) -> 'Future[_IOStreamType]': """Connects the socket to a remote address without blocking. May only be called if the socket passed to the constructor was @@ -1122,7 +1142,8 @@ class IOStream(BaseIOStream): """ self._connecting = True - future = self._connect_future = Future() + future = Future() # type: Future[_IOStreamType] + self._connect_future = typing.cast('Future[IOStream]', future) try: self.socket.connect(address) except socket.error as e: @@ -1143,7 +1164,9 @@ class IOStream(BaseIOStream): self._add_io_state(self.io_loop.WRITE) return future - def start_tls(self, server_side, ssl_options=None, server_hostname=None): + def start_tls(self, server_side: bool, + ssl_options: Union[Dict[str, Any], ssl.SSLContext]=None, + server_hostname: str=None) -> Awaitable['SSLIOStream']: """Convert this `IOStream` to an `SSLIOStream`. This enables protocols that begin in clear-text mode and @@ -1192,7 +1215,7 @@ class IOStream(BaseIOStream): socket = self.socket self.io_loop.remove_handler(socket) - self.socket = None + self.socket = None # type: ignore socket = ssl_wrap_socket(socket, ssl_options, server_hostname=server_hostname, server_side=server_side, @@ -1200,7 +1223,7 @@ class IOStream(BaseIOStream): orig_close_callback = self._close_callback self._close_callback = None - future = Future() + future = Future() # type: Future[SSLIOStream] ssl_stream = SSLIOStream(socket, ssl_options=ssl_options) ssl_stream.set_close_callback(orig_close_callback) ssl_stream._ssl_connect_future = future @@ -1208,7 +1231,7 @@ class IOStream(BaseIOStream): ssl_stream.read_chunk_size = self.read_chunk_size return future - def _handle_connect(self): + def _handle_connect(self) -> None: try: err = self.socket.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) except socket.error as e: @@ -1233,7 +1256,7 @@ class IOStream(BaseIOStream): future.set_result(self) self._connecting = False - def set_nodelay(self, value): + def set_nodelay(self, value: bool) -> None: if (self.socket is not None and self.socket.family in (socket.AF_INET, socket.AF_INET6)): try: @@ -1258,7 +1281,9 @@ class SSLIOStream(IOStream): before constructing the `SSLIOStream`. Unconnected sockets will be wrapped when `IOStream.connect` is finished. """ - def __init__(self, *args, **kwargs): + socket = None # type: ssl.SSLSocket + + def __init__(self, *args: Any, **kwargs: Any) -> None: """The ``ssl_options`` keyword argument may either be an `ssl.SSLContext` object or a dictionary of keywords arguments for `ssl.wrap_socket` @@ -1268,7 +1293,7 @@ class SSLIOStream(IOStream): self._ssl_accepting = True self._handshake_reading = False self._handshake_writing = False - self._server_hostname = None + self._server_hostname = None # type: Optional[str] # If the socket is already connected, attempt to start the handshake. try: @@ -1281,13 +1306,13 @@ class SSLIOStream(IOStream): # _handle_events. self._add_io_state(self.io_loop.WRITE) - def reading(self): + def reading(self) -> bool: return self._handshake_reading or super(SSLIOStream, self).reading() - def writing(self): + def writing(self) -> bool: return self._handshake_writing or super(SSLIOStream, self).writing() - def _do_ssl_handshake(self): + def _do_ssl_handshake(self) -> None: # Based on code from test_ssl.py in the python stdlib try: self._handshake_reading = False @@ -1333,13 +1358,13 @@ class SSLIOStream(IOStream): return self._finish_ssl_connect() - def _finish_ssl_connect(self): + def _finish_ssl_connect(self) -> None: if self._ssl_connect_future is not None: future = self._ssl_connect_future self._ssl_connect_future = None future.set_result(self) - def _verify_cert(self, peercert): + def _verify_cert(self, peercert: Any) -> bool: """Returns True if peercert is valid according to the configured validation mode and hostname. @@ -1366,19 +1391,19 @@ class SSLIOStream(IOStream): else: return True - def _handle_read(self): + def _handle_read(self) -> None: if self._ssl_accepting: self._do_ssl_handshake() return super(SSLIOStream, self)._handle_read() - def _handle_write(self): + def _handle_write(self) -> None: if self._ssl_accepting: self._do_ssl_handshake() return super(SSLIOStream, self)._handle_write() - def connect(self, address, server_hostname=None): + def connect(self, address: Tuple, server_hostname: str=None) -> 'Future[SSLIOStream]': self._server_hostname = server_hostname # Ignore the result of connect(). If it fails, # wait_for_handshake will raise an error too. This is @@ -1395,7 +1420,7 @@ class SSLIOStream(IOStream): fut.add_done_callback(lambda f: f.exception()) return self.wait_for_handshake() - def _handle_connect(self): + def _handle_connect(self) -> None: # Call the superclass method to check for errors. super(SSLIOStream, self)._handle_connect() if self.closed(): @@ -1412,13 +1437,14 @@ class SSLIOStream(IOStream): # wrap_socket(). self.io_loop.remove_handler(self.socket) old_state = self._state + assert old_state is not None self._state = None self.socket = ssl_wrap_socket(self.socket, self._ssl_options, server_hostname=self._server_hostname, do_handshake_on_connect=False) self._add_io_state(old_state) - def wait_for_handshake(self): + def wait_for_handshake(self) -> 'Future[SSLIOStream]': """Wait for the initial SSL handshake to complete. If a ``callback`` is given, it will be called with no @@ -1450,9 +1476,9 @@ class SSLIOStream(IOStream): self._finish_ssl_connect() return future - def write_to_fd(self, data): + def write_to_fd(self, data: memoryview) -> int: try: - return self.socket.send(data) + return self.socket.send(data) # type: ignore except ssl.SSLError as e: if e.args[0] == ssl.SSL_ERROR_WANT_WRITE: # In Python 3.5+, SSLSocket.send raises a WANT_WRITE error if @@ -1468,7 +1494,7 @@ class SSLIOStream(IOStream): # See https://github.com/tornadoweb/tornado/pull/2008 del data - def read_from_fd(self, buf): + def read_from_fd(self, buf: Union[bytearray, memoryview]) -> Optional[int]: try: if self._ssl_accepting: # If the handshake hasn't finished yet, there can't be anything @@ -1476,7 +1502,7 @@ class SSLIOStream(IOStream): # depending on the SSL version) return None try: - return self.socket.recv_into(buf) + return self.socket.recv_into(buf, len(buf)) except ssl.SSLError as e: # SSLError is a subclass of socket.error, so this except # block must come first. @@ -1490,9 +1516,9 @@ class SSLIOStream(IOStream): else: raise finally: - buf = None + del buf - def _is_connreset(self, e): + def _is_connreset(self, e: BaseException) -> bool: if isinstance(e, ssl.SSLError) and e.args[0] == ssl.SSL_ERROR_EOF: return True return super(SSLIOStream, self)._is_connreset(e) @@ -1506,29 +1532,29 @@ class PipeIOStream(BaseIOStream): one-way, so a `PipeIOStream` can be used for reading or writing but not both. """ - def __init__(self, fd, *args, **kwargs): + def __init__(self, fd: int, *args: Any, **kwargs: Any) -> None: self.fd = fd self._fio = io.FileIO(self.fd, "r+") _set_nonblocking(fd) super(PipeIOStream, self).__init__(*args, **kwargs) - def fileno(self): + def fileno(self) -> int: return self.fd - def close_fd(self): + def close_fd(self) -> None: self._fio.close() - def write_to_fd(self, data): + def write_to_fd(self, data: memoryview) -> int: try: - return os.write(self.fd, data) + return os.write(self.fd, data) # type: ignore finally: # Avoid keeping to data, which can be a memoryview. # See https://github.com/tornadoweb/tornado/pull/2008 del data - def read_from_fd(self, buf): + def read_from_fd(self, buf: Union[bytearray, memoryview]) -> Optional[int]: try: - return self._fio.readinto(buf) + return self._fio.readinto(buf) # type: ignore except (IOError, OSError) as e: if errno_from_exception(e) == errno.EBADF: # If the writing half of a pipe is closed, select will @@ -1538,9 +1564,9 @@ class PipeIOStream(BaseIOStream): else: raise finally: - buf = None + del buf -def doctests(): +def doctests() -> Any: import doctest return doctest.DocTestSuite() diff --git a/tornado/test/iostream_test.py b/tornado/test/iostream_test.py index 9c6dbf76e..63d84ab1e 100644 --- a/tornado/test/iostream_test.py +++ b/tornado/test/iostream_test.py @@ -649,7 +649,7 @@ class TestIOStreamMixin(TestReadWriteMixin): @gen.coroutine def make_iostream_pair(self, **kwargs): listener, port = bind_unused_port() - server_stream_fut = Future() + server_stream_fut = Future() # type: Future[IOStream] def accept_callback(connection, address): server_stream_fut.set_result(self._make_server_iostream(connection, **kwargs)) @@ -679,11 +679,11 @@ class TestIOStreamMixin(TestReadWriteMixin): self.assertTrue(isinstance(stream.error, socket.error), stream.error) if sys.platform != 'cygwin': - _ERRNO_CONNREFUSED = (errno.ECONNREFUSED,) + _ERRNO_CONNREFUSED = [errno.ECONNREFUSED] if hasattr(errno, "WSAECONNREFUSED"): - _ERRNO_CONNREFUSED += (errno.WSAECONNREFUSED,) + _ERRNO_CONNREFUSED.append(errno.WSAECONNREFUSED) # type: ignore # cygwin's errnos don't match those used on native windows python - self.assertTrue(stream.error.args[0] in _ERRNO_CONNREFUSED) + self.assertTrue(stream.error.args[0] in _ERRNO_CONNREFUSED) # type: ignore @gen_test def test_gaierror(self): @@ -849,7 +849,7 @@ class TestIOStreamStartTLS(AsyncTestCase): super(TestIOStreamStartTLS, self).setUp() self.listener, self.port = bind_unused_port() self.server_stream = None - self.server_accepted = Future() + self.server_accepted = Future() # type: Future[None] netutil.add_accept_handler(self.listener, self.accept) self.client_stream = IOStream(socket.socket()) self.io_loop.add_future(self.client_stream.connect( @@ -969,7 +969,7 @@ class WaitForHandshakeTest(AsyncTestCase): @gen_test def test_wait_for_handshake_future(self): test = self - handshake_future = Future() + handshake_future = Future() # type: Future[None] class TestServer(TCPServer): def handle_stream(self, stream, address): @@ -987,7 +987,7 @@ class WaitForHandshakeTest(AsyncTestCase): @gen_test def test_wait_for_handshake_already_waiting_error(self): test = self - handshake_future = Future() + handshake_future = Future() # type: Future[None] class TestServer(TCPServer): @gen.coroutine @@ -1003,7 +1003,7 @@ class WaitForHandshakeTest(AsyncTestCase): @gen_test def test_wait_for_handshake_already_connected(self): - handshake_future = Future() + handshake_future = Future() # type: Future[None] class TestServer(TCPServer): @gen.coroutine