from ..config import Timeout
from ..exceptions import ConnectTimeout, ReadTimeout, WriteTimeout
+from ..utils import as_network_error
from .base import BaseLock, BaseSemaphore, BaseSocketStream, ConcurrencyBackend
SSL_MONKEY_PATCH_APPLIED = False
async def read(self, n: int, timeout: Timeout) -> bytes:
try:
async with self.read_lock:
- return await asyncio.wait_for(
- self.stream_reader.read(n), timeout.read_timeout
- )
+ with as_network_error(OSError):
+ return await asyncio.wait_for(
+ self.stream_reader.read(n), timeout.read_timeout
+ )
except asyncio.TimeoutError:
raise ReadTimeout() from None
try:
async with self.write_lock:
- self.stream_writer.write(data)
- return await asyncio.wait_for(
- self.stream_writer.drain(), timeout.write_timeout
- )
+ with as_network_error(OSError):
+ self.stream_writer.write(data)
+ return await asyncio.wait_for(
+ self.stream_writer.drain(), timeout.write_timeout
+ )
except asyncio.TimeoutError:
raise WriteTimeout() from None
# stream, meaning that at best it will happen during the next event loop
# iteration, and at worst asyncio will take care of it on program exit.
async with self.write_lock:
- self.stream_writer.close()
+ with as_network_error(OSError):
+ self.stream_writer.close()
class AsyncioBackend(ConcurrencyBackend):
timeout: Timeout,
) -> SocketStream:
try:
- stream_reader, stream_writer = await asyncio.wait_for( # type: ignore
- asyncio.open_connection(hostname, port, ssl=ssl_context),
- timeout.connect_timeout,
- )
+ with as_network_error(OSError):
+ stream_reader, stream_writer = await asyncio.wait_for( # type: ignore
+ asyncio.open_connection(hostname, port, ssl=ssl_context),
+ timeout.connect_timeout,
+ )
except asyncio.TimeoutError:
raise ConnectTimeout()
server_hostname = hostname if ssl_context else None
try:
- stream_reader, stream_writer = await asyncio.wait_for( # type: ignore
- asyncio.open_unix_connection(
- path, ssl=ssl_context, server_hostname=server_hostname
- ),
- timeout.connect_timeout,
- )
+ with as_network_error(OSError):
+ stream_reader, stream_writer = await asyncio.wait_for( # type: ignore
+ asyncio.open_unix_connection(
+ path, ssl=ssl_context, server_hostname=server_hostname
+ ),
+ timeout.connect_timeout,
+ )
except asyncio.TimeoutError:
raise ConnectTimeout()
from ..config import Timeout
from ..exceptions import ConnectTimeout, ReadTimeout, WriteTimeout
+from ..utils import as_network_error
from .base import BaseLock, BaseSemaphore, BaseSocketStream, ConcurrencyBackend
)
with trio.move_on_after(connect_timeout):
- await ssl_stream.do_handshake()
+ with as_network_error(trio.BrokenResourceError):
+ await ssl_stream.do_handshake()
return SocketStream(ssl_stream)
raise ConnectTimeout()
with trio.move_on_after(read_timeout):
async with self.read_lock:
- return await self.stream.receive_some(max_bytes=n)
+ with as_network_error(trio.BrokenResourceError):
+ return await self.stream.receive_some(max_bytes=n)
raise ReadTimeout()
with trio.move_on_after(write_timeout):
async with self.write_lock:
- return await self.stream.send_all(data)
+ with as_network_error(trio.BrokenResourceError):
+ return await self.stream.send_all(data)
raise WriteTimeout()
connect_timeout = none_as_inf(timeout.connect_timeout)
with trio.move_on_after(connect_timeout):
- stream: trio.SocketStream = await trio.open_tcp_stream(hostname, port)
+ with as_network_error(OSError):
+ stream: trio.SocketStream = await trio.open_tcp_stream(hostname, port)
+
if ssl_context is not None:
stream = trio.SSLStream(stream, ssl_context, server_hostname=hostname)
- await stream.do_handshake()
+ with as_network_error(trio.BrokenResourceError):
+ await stream.do_handshake()
+
return SocketStream(stream=stream)
raise ConnectTimeout()
connect_timeout = none_as_inf(timeout.connect_timeout)
with trio.move_on_after(connect_timeout):
- stream: trio.SocketStream = await trio.open_unix_socket(path)
+ with as_network_error(OSError):
+ stream: trio.SocketStream = await trio.open_unix_socket(path)
+
if ssl_context is not None:
stream = trio.SSLStream(stream, ssl_context, server_hostname=hostname)
- await stream.do_handshake()
+ with as_network_error(trio.BrokenResourceError):
+ await stream.do_handshake()
+
return SocketStream(stream=stream)
raise ConnectTimeout()
"""
+# Network exceptions...
+
+
+class NetworkError(HTTPError):
+ """
+ A failure occurred while trying to access the network.
+ """
+
+
+class ConnectionClosed(NetworkError):
+ """
+ Expected more data from peer, but connection was closed.
+ """
+
+
# Redirect exceptions...
"""
-class ConnectionClosed(HTTPError):
- """
- Expected more data from peer, but connection was closed.
- """
-
-
class CookieConflict(HTTPError):
"""
Attempted to lookup a cookie by name, but multiple cookies existed.
import codecs
import collections
+import contextlib
import logging
import netrc
import os
from types import TracebackType
from urllib.request import getproxies
+from .exceptions import NetworkError
+
if typing.TYPE_CHECKING: # pragma: no cover
from .models import PrimitiveData
from .models import URL
if self.end is None:
return timedelta(seconds=perf_counter() - self.start)
return timedelta(seconds=self.end - self.start)
+
+
+@contextlib.contextmanager
+def as_network_error(*exception_classes: type) -> typing.Iterator[None]:
+ try:
+ yield
+ except BaseException as exc:
+ for cls in exception_classes:
+ if isinstance(exc, cls):
+ raise NetworkError(exc) from exc
+ raise