"""
import abc
+import asyncio
import base64
import hashlib
import os
self.stream = None # type: Optional[IOStream]
self._on_close_called = False
- def get(self, *args: Any, **kwargs: Any) -> None:
+ async def get(self, *args: Any, **kwargs: Any) -> None:
self.open_args = args
self.open_kwargs = kwargs
self.ws_connection = self.get_websocket_protocol()
if self.ws_connection:
- self.ws_connection.accept_connection(self)
+ await self.ws_connection.accept_connection(self)
else:
self.set_status(426, "Upgrade Required")
self.set_header("Sec-WebSocket-Version", "7, 8, 13")
- self.finish()
stream = None
raise NotImplementedError()
@abc.abstractmethod
- def accept_connection(self, handler: WebSocketHandler) -> None:
+ async def accept_connection(self, handler: WebSocketHandler) -> None:
raise NotImplementedError()
@abc.abstractmethod
self._masked_frame = None
self._frame_mask = None # type: Optional[bytes]
self._frame_length = None
- self._fragmented_message_buffer = None
+ self._fragmented_message_buffer = None # type: Optional[bytes]
self._fragmented_message_opcode = None
self._waiting = None # type: object
self._compression_options = params.compression_options
def selected_subprotocol(self, value: Optional[str]) -> None:
self._selected_subprotocol = value
- def accept_connection(self, handler: WebSocketHandler) -> None:
+ async def accept_connection(self, handler: WebSocketHandler) -> None:
try:
self._handle_websocket_headers(handler)
except ValueError:
return
try:
- self._accept_connection(handler)
+ await self._accept_connection(handler)
+ except asyncio.CancelledError:
+ self._abort()
+ return
except ValueError:
gen_log.debug("Malformed WebSocket request received", exc_info=True)
self._abort()
cast(str, handler.request.headers.get("Sec-Websocket-Key"))
)
- @gen.coroutine
- def _accept_connection(
- self, handler: WebSocketHandler
- ) -> Generator[Any, Any, None]:
+ async def _accept_connection(self, handler: WebSocketHandler) -> None:
subprotocol_header = handler.request.headers.get("Sec-WebSocket-Protocol")
if subprotocol_header:
subprotocols = [s.strip() for s in subprotocol_header.split(",")]
handler.open, *handler.open_args, **handler.open_kwargs
)
if open_result is not None:
- yield open_result
- yield self._receive_frame_loop()
+ await open_result
+ await self._receive_frame_loop()
def _parse_extensions_header(
self, headers: httputil.HTTPHeaders