class SocketStream(BaseSocketStream):
def __init__(
- self,
- stream_reader: asyncio.StreamReader,
- stream_writer: asyncio.StreamWriter,
- timeout: Timeout,
+ self, stream_reader: asyncio.StreamReader, stream_writer: asyncio.StreamWriter,
):
self.stream_reader = stream_reader
self.stream_writer = stream_writer
- self.timeout = timeout
self.read_lock = asyncio.Lock()
self._inner: typing.Optional[SocketStream] = None
transport=transport, protocol=protocol, reader=stream_reader, loop=loop
)
- ssl_stream = SocketStream(stream_reader, stream_writer, self.timeout)
+ ssl_stream = SocketStream(stream_reader, stream_writer)
# When we return a new SocketStream with new StreamReader/StreamWriter instances
# we need to keep references to the old StreamReader/StreamWriter so that they
# are not garbage collected and closed while we're still using them.
ident = ssl_object.selected_alpn_protocol()
return "HTTP/2" if ident == "h2" else "HTTP/1.1"
- async def read(
- self, n: int, timeout: Timeout = None, flag: TimeoutFlag = None
- ) -> bytes:
- if timeout is None:
- timeout = self.timeout
-
+ async def read(self, n: int, timeout: Timeout, flag: TimeoutFlag = None) -> bytes:
while True:
# Check our flag at the first possible moment, and use a fine
# grained retry loop if we're not yet in read-timeout mode.
return data
async def write(
- self, data: bytes, timeout: Timeout = None, flag: TimeoutFlag = None
+ self, data: bytes, timeout: Timeout, flag: TimeoutFlag = None
) -> None:
if not data:
return
- if timeout is None:
- timeout = self.timeout
-
self.stream_writer.write(data)
while True:
try:
except asyncio.TimeoutError:
raise ConnectTimeout()
- return SocketStream(
- stream_reader=stream_reader, stream_writer=stream_writer, timeout=timeout
- )
+ return SocketStream(stream_reader=stream_reader, stream_writer=stream_writer)
async def open_uds_stream(
self,
except asyncio.TimeoutError:
raise ConnectTimeout()
- return SocketStream(
- stream_reader=stream_reader, stream_writer=stream_writer, timeout=timeout
- )
+ return SocketStream(stream_reader=stream_reader, stream_writer=stream_writer)
async def run_in_threadpool(
self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any
class SocketStream(BaseSocketStream):
def __init__(
- self, stream: typing.Union[trio.SocketStream, trio.SSLStream], timeout: Timeout,
+ self, stream: typing.Union[trio.SocketStream, trio.SSLStream],
) -> None:
self.stream = stream
- self.timeout = timeout
self.read_lock = trio.Lock()
self.write_lock = trio.Lock()
if cancel_scope.cancelled_caught:
raise ConnectTimeout()
- return SocketStream(ssl_stream, self.timeout)
+ return SocketStream(ssl_stream)
def get_http_version(self) -> str:
if not isinstance(self.stream, trio.SSLStream):
ident = self.stream.selected_alpn_protocol()
return "HTTP/2" if ident == "h2" else "HTTP/1.1"
- async def read(
- self, n: int, timeout: Timeout = None, flag: TimeoutFlag = None
- ) -> bytes:
- if timeout is None:
- timeout = self.timeout
-
+ async def read(self, n: int, timeout: Timeout, flag: TimeoutFlag = None) -> bytes:
while True:
# Check our flag at the first possible moment, and use a fine
# grained retry loop if we're not yet in read-timeout mode.
return stream.socket.is_readable()
async def write(
- self, data: bytes, timeout: Timeout = None, flag: TimeoutFlag = None
+ self, data: bytes, timeout: Timeout, flag: TimeoutFlag = None
) -> None:
if not data:
return
- if timeout is None:
- timeout = self.timeout
-
write_timeout = _or_inf(timeout.write_timeout)
while True:
if cancel_scope.cancelled_caught:
raise ConnectTimeout()
- return SocketStream(stream=stream, timeout=timeout)
+ return SocketStream(stream=stream)
async def open_uds_stream(
self,
if cancel_scope.cancelled_caught:
raise ConnectTimeout()
- return SocketStream(stream=stream, timeout=timeout)
+ return SocketStream(stream=stream)
async def run_in_threadpool(
self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any
return stream.stream.cipher() if isinstance(stream.stream, trio.SSLStream) else None
-async def read_response(stream, timeout: float, should_contain: bytes) -> bytes:
+async def read_response(stream, timeout: Timeout, should_contain: bytes) -> bytes:
# stream.read() only gives us *up to* as much data as we ask for. In order to
# cleanly close the stream, we must read until the end of the HTTP response.
response = b""
assert stream.is_connection_dropped() is False
assert get_cipher(stream) is not None
- await stream.write(b"GET / HTTP/1.1\r\n\r\n")
+ await stream.write(b"GET / HTTP/1.1\r\n\r\n", timeout)
response = await read_response(stream, timeout, should_contain=b"Hello, world")
assert response.startswith(b"HTTP/1.1 200 OK\r\n")
assert stream.is_connection_dropped() is False
assert get_cipher(stream) is not None
- await stream.write(b"GET / HTTP/1.1\r\n\r\n")
+ await stream.write(b"GET / HTTP/1.1\r\n\r\n", timeout)
response = await read_response(stream, timeout, should_contain=b"Hello, world")
assert response.startswith(b"HTTP/1.1 200 OK\r\n")
stream = await backend.open_tcp_stream(
server.url.host, server.url.port, ssl_context=None, timeout=Timeout(5)
)
+ timeout = Timeout(5)
try:
- await stream.write(b"GET / HTTP/1.1\r\n\r\n")
+ await stream.write(b"GET / HTTP/1.1\r\n\r\n", timeout)
await run_concurrently(
- backend, lambda: stream.read(10), lambda: stream.read(10)
+ backend, lambda: stream.read(10, timeout), lambda: stream.read(10, timeout)
)
finally:
await stream.close()