]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Wrap network errors in HTTPX-specific exceptions (#707)
authorFlorimond Manca <florimond.manca@gmail.com>
Tue, 7 Jan 2020 10:01:11 +0000 (11:01 +0100)
committerTom Christie <tom@tomchristie.com>
Tue, 7 Jan 2020 10:01:11 +0000 (10:01 +0000)
httpx/backends/asyncio.py
httpx/backends/trio.py
httpx/exceptions.py
httpx/utils.py

index 74d2b0d4f854424131c49c28584811306357934e..8d1025748bd07eaeea44c689ce19c9f881e926bf 100644 (file)
@@ -4,6 +4,7 @@ import typing
 
 from ..config import Timeout
 from ..exceptions import ConnectTimeout, ReadTimeout, WriteTimeout
+from ..utils import as_network_error
 from .base import BaseLock, BaseSemaphore, BaseSocketStream, ConcurrencyBackend
 
 SSL_MONKEY_PATCH_APPLIED = False
@@ -125,9 +126,10 @@ class SocketStream(BaseSocketStream):
     async def read(self, n: int, timeout: Timeout) -> bytes:
         try:
             async with self.read_lock:
-                return await asyncio.wait_for(
-                    self.stream_reader.read(n), timeout.read_timeout
-                )
+                with as_network_error(OSError):
+                    return await asyncio.wait_for(
+                        self.stream_reader.read(n), timeout.read_timeout
+                    )
         except asyncio.TimeoutError:
             raise ReadTimeout() from None
 
@@ -137,10 +139,11 @@ class SocketStream(BaseSocketStream):
 
         try:
             async with self.write_lock:
-                self.stream_writer.write(data)
-                return await asyncio.wait_for(
-                    self.stream_writer.drain(), timeout.write_timeout
-                )
+                with as_network_error(OSError):
+                    self.stream_writer.write(data)
+                    return await asyncio.wait_for(
+                        self.stream_writer.drain(), timeout.write_timeout
+                    )
         except asyncio.TimeoutError:
             raise WriteTimeout() from None
 
@@ -170,7 +173,8 @@ class SocketStream(BaseSocketStream):
         # stream, meaning that at best it will happen during the next event loop
         # iteration, and at worst asyncio will take care of it on program exit.
         async with self.write_lock:
-            self.stream_writer.close()
+            with as_network_error(OSError):
+                self.stream_writer.close()
 
 
 class AsyncioBackend(ConcurrencyBackend):
@@ -189,10 +193,11 @@ class AsyncioBackend(ConcurrencyBackend):
         timeout: Timeout,
     ) -> SocketStream:
         try:
-            stream_reader, stream_writer = await asyncio.wait_for(  # type: ignore
-                asyncio.open_connection(hostname, port, ssl=ssl_context),
-                timeout.connect_timeout,
-            )
+            with as_network_error(OSError):
+                stream_reader, stream_writer = await asyncio.wait_for(  # type: ignore
+                    asyncio.open_connection(hostname, port, ssl=ssl_context),
+                    timeout.connect_timeout,
+                )
         except asyncio.TimeoutError:
             raise ConnectTimeout()
 
@@ -208,12 +213,13 @@ class AsyncioBackend(ConcurrencyBackend):
         server_hostname = hostname if ssl_context else None
 
         try:
-            stream_reader, stream_writer = await asyncio.wait_for(  # type: ignore
-                asyncio.open_unix_connection(
-                    path, ssl=ssl_context, server_hostname=server_hostname
-                ),
-                timeout.connect_timeout,
-            )
+            with as_network_error(OSError):
+                stream_reader, stream_writer = await asyncio.wait_for(  # type: ignore
+                    asyncio.open_unix_connection(
+                        path, ssl=ssl_context, server_hostname=server_hostname
+                    ),
+                    timeout.connect_timeout,
+                )
         except asyncio.TimeoutError:
             raise ConnectTimeout()
 
index 979aa450b7899f10ec7f09e6e05f6ef91f51a32c..33e93e9677e21f5b51cef933d54a88d7702df55b 100644 (file)
@@ -5,6 +5,7 @@ import trio
 
 from ..config import Timeout
 from ..exceptions import ConnectTimeout, ReadTimeout, WriteTimeout
+from ..utils import as_network_error
 from .base import BaseLock, BaseSemaphore, BaseSocketStream, ConcurrencyBackend
 
 
@@ -29,7 +30,8 @@ class SocketStream(BaseSocketStream):
         )
 
         with trio.move_on_after(connect_timeout):
-            await ssl_stream.do_handshake()
+            with as_network_error(trio.BrokenResourceError):
+                await ssl_stream.do_handshake()
             return SocketStream(ssl_stream)
 
         raise ConnectTimeout()
@@ -46,7 +48,8 @@ class SocketStream(BaseSocketStream):
 
         with trio.move_on_after(read_timeout):
             async with self.read_lock:
-                return await self.stream.receive_some(max_bytes=n)
+                with as_network_error(trio.BrokenResourceError):
+                    return await self.stream.receive_some(max_bytes=n)
 
         raise ReadTimeout()
 
@@ -58,7 +61,8 @@ class SocketStream(BaseSocketStream):
 
         with trio.move_on_after(write_timeout):
             async with self.write_lock:
-                return await self.stream.send_all(data)
+                with as_network_error(trio.BrokenResourceError):
+                    return await self.stream.send_all(data)
 
         raise WriteTimeout()
 
@@ -93,10 +97,14 @@ class TrioBackend(ConcurrencyBackend):
         connect_timeout = none_as_inf(timeout.connect_timeout)
 
         with trio.move_on_after(connect_timeout):
-            stream: trio.SocketStream = await trio.open_tcp_stream(hostname, port)
+            with as_network_error(OSError):
+                stream: trio.SocketStream = await trio.open_tcp_stream(hostname, port)
+
             if ssl_context is not None:
                 stream = trio.SSLStream(stream, ssl_context, server_hostname=hostname)
-                await stream.do_handshake()
+                with as_network_error(trio.BrokenResourceError):
+                    await stream.do_handshake()
+
             return SocketStream(stream=stream)
 
         raise ConnectTimeout()
@@ -111,10 +119,14 @@ class TrioBackend(ConcurrencyBackend):
         connect_timeout = none_as_inf(timeout.connect_timeout)
 
         with trio.move_on_after(connect_timeout):
-            stream: trio.SocketStream = await trio.open_unix_socket(path)
+            with as_network_error(OSError):
+                stream: trio.SocketStream = await trio.open_unix_socket(path)
+
             if ssl_context is not None:
                 stream = trio.SSLStream(stream, ssl_context, server_hostname=hostname)
-                await stream.do_handshake()
+                with as_network_error(trio.BrokenResourceError):
+                    await stream.do_handshake()
+
             return SocketStream(stream=stream)
 
         raise ConnectTimeout()
index e1992708f6d7bc887719d51f21547a2315ed03a8..9f2119852c8ac34885a11ad53ea43545773f7bf0 100644 (file)
@@ -71,6 +71,21 @@ class DecodingError(HTTPError):
     """
 
 
+# Network exceptions...
+
+
+class NetworkError(HTTPError):
+    """
+    A failure occurred while trying to access the network.
+    """
+
+
+class ConnectionClosed(NetworkError):
+    """
+    Expected more data from peer, but connection was closed.
+    """
+
+
 # Redirect exceptions...
 
 
@@ -147,12 +162,6 @@ class InvalidURL(HTTPError):
     """
 
 
-class ConnectionClosed(HTTPError):
-    """
-    Expected more data from peer, but connection was closed.
-    """
-
-
 class CookieConflict(HTTPError):
     """
     Attempted to lookup a cookie by name, but multiple cookies existed.
index dfa9af8aaf1c69a7428ba4fa20c0225bc5231bfc..5700a1cc8affba3086e6f06149073c74708338da 100644 (file)
@@ -1,5 +1,6 @@
 import codecs
 import collections
+import contextlib
 import logging
 import netrc
 import os
@@ -12,6 +13,8 @@ from time import perf_counter
 from types import TracebackType
 from urllib.request import getproxies
 
+from .exceptions import NetworkError
+
 if typing.TYPE_CHECKING:  # pragma: no cover
     from .models import PrimitiveData
     from .models import URL
@@ -353,3 +356,14 @@ class ElapsedTimer:
         if self.end is None:
             return timedelta(seconds=perf_counter() - self.start)
         return timedelta(seconds=self.end - self.start)
+
+
+@contextlib.contextmanager
+def as_network_error(*exception_classes: type) -> typing.Iterator[None]:
+    try:
+        yield
+    except BaseException as exc:
+        for cls in exception_classes:
+            if isinstance(exc, cls):
+                raise NetworkError(exc) from exc
+        raise