]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Add debug logs to ASGIDispatch (#371)
authorFlorimond Manca <florimond.manca@gmail.com>
Mon, 23 Sep 2019 20:49:24 +0000 (22:49 +0200)
committerGitHub <noreply@github.com>
Mon, 23 Sep 2019 20:49:24 +0000 (22:49 +0200)
* Add debug logs to ASGIDispatch

* Tidy up ASGIDispatch

* Log entire scope and ASGI messages

* Obfuscate sensitive headers using common utility

* Update utils.py

httpx/dispatch/asgi.py
httpx/models.py
httpx/utils.py
tests/models/test_headers.py
tests/test_utils.py

index 0633ea8f007f0ac74b82e0450157d372e4cdd352..72fef53faddba0dd46aed46be370f2674b52ece9 100644 (file)
@@ -4,14 +4,17 @@ from ..concurrency.asyncio import AsyncioBackend
 from ..concurrency.base import ConcurrencyBackend
 from ..config import CertTypes, TimeoutTypes, VerifyTypes
 from ..models import AsyncRequest, AsyncResponse
+from ..utils import MessageLoggerASGIMiddleware, get_logger
 from .base import AsyncDispatcher
 
+logger = get_logger(__name__)
+
 
 class ASGIDispatch(AsyncDispatcher):
     """
     A custom dispatcher that handles sending requests directly to an ASGI app.
 
-    The simplest way to use this functionality is to use the `app`argument.
+    The simplest way to use this functionality is to use the `app` argument.
     This will automatically infer if 'app' is a WSGI or an ASGI application,
     and will setup an appropriate dispatch class:
 
@@ -30,6 +33,7 @@ class ASGIDispatch(AsyncDispatcher):
         client=("1.2.3.4", 123)
     )
     client = httpx.Client(dispatch=dispatch)
+    ```
 
     Arguments:
 
@@ -77,7 +81,7 @@ class ASGIDispatch(AsyncDispatcher):
             "client": self.client,
             "root_path": self.root_path,
         }
-        app = self.app
+        app = MessageLoggerASGIMiddleware(self.app, logger=logger)
         app_exc = None
         status_code = None
         headers = None
@@ -86,8 +90,6 @@ class ASGIDispatch(AsyncDispatcher):
         request_stream = request.stream()
 
         async def receive() -> dict:
-            nonlocal request_stream
-
             try:
                 body = await request_stream.__anext__()
             except StopAsyncIteration:
@@ -95,23 +97,25 @@ class ASGIDispatch(AsyncDispatcher):
             return {"type": "http.request", "body": body, "more_body": True}
 
         async def send(message: dict) -> None:
-            nonlocal status_code, headers, response_started_or_failed
-            nonlocal response_body, request
+            nonlocal status_code, headers
 
             if message["type"] == "http.response.start":
                 status_code = message["status"]
                 headers = message.get("headers", [])
                 response_started_or_failed.set()
+
             elif message["type"] == "http.response.body":
                 body = message.get("body", b"")
                 more_body = message.get("more_body", False)
+
                 if body and request.method != "HEAD":
                     await response_body.put(body)
+
                 if not more_body:
                     await response_body.mark_as_done()
 
         async def run_app() -> None:
-            nonlocal app, scope, receive, send, app_exc, response_body
+            nonlocal app_exc
             try:
                 await app(scope, receive, send)
             except Exception as exc:
@@ -137,7 +141,6 @@ class ASGIDispatch(AsyncDispatcher):
         assert headers is not None
 
         async def on_close() -> None:
-            nonlocal response_body
             await response_body.drain()
             await background.close(app_exc)
             if app_exc is not None and self.raise_app_exceptions:
index 261eefb24ad6e0729af64d63e390d84d3e6abb73..7b256684555c8976f8a64ae5c95c1ca4a8c925d8 100644 (file)
@@ -36,6 +36,7 @@ from .utils import (
     is_known_encoding,
     normalize_header_key,
     normalize_header_value,
+    obfuscate_sensitive_headers,
     parse_header_links,
     str_query_param,
 )
@@ -554,13 +555,11 @@ class Headers(typing.MutableMapping[str, str]):
         if self.encoding != "ascii":
             encoding_str = f", encoding={self.encoding!r}"
 
-        sensitive_headers = {"authorization", "proxy-authorization"}
-        as_list = [
-            (k, "[secure]" if k in sensitive_headers else v) for k, v in self.items()
-        ]
-
+        as_list = list(obfuscate_sensitive_headers(self.items()))
         as_dict = dict(as_list)
-        if len(as_dict) == len(as_list):
+
+        no_duplicate_keys = len(as_dict) == len(as_list)
+        if no_duplicate_keys:
             return f"{class_name}({as_dict!r}{encoding_str})"
         return f"{class_name}({as_list!r}{encoding_str})"
 
index cff81849303d53aac2d50bc718f11f8a433c4365..c8fcb1e66893333c161254568405147e51544023 100644 (file)
@@ -160,6 +160,18 @@ def parse_header_links(value: str) -> typing.List[typing.Dict[str, str]]:
     return links
 
 
+SENSITIVE_HEADERS = {"authorization", "proxy-authorization"}
+
+
+def obfuscate_sensitive_headers(
+    items: typing.Iterable[typing.Tuple[typing.AnyStr, typing.AnyStr]]
+) -> typing.Iterator[typing.Tuple[typing.AnyStr, typing.AnyStr]]:
+    for k, v in items:
+        if to_str(k.lower()) in SENSITIVE_HEADERS:
+            v = to_bytes_or_str("[secure]", match_type_of=v)
+        yield k, v
+
+
 _LOGGER_INITIALIZED = False
 
 
@@ -187,6 +199,15 @@ def get_logger(name: str) -> logging.Logger:
     return logging.getLogger(name)
 
 
+def kv_format(**kwargs: typing.Any) -> str:
+    """Format arguments into a key=value line.
+
+    >>> formatkv(x=1, name="Bob")
+    "x=1 name='Bob'"
+    """
+    return " ".join(f"{key}={value!r}" for key, value in kwargs.items())
+
+
 def get_environment_proxies() -> typing.Dict[str, str]:
     """Gets proxy information from the environment"""
 
@@ -227,6 +248,10 @@ def to_str(value: typing.Union[str, bytes], encoding: str = "utf-8") -> str:
     return value if isinstance(value, str) else value.decode(encoding)
 
 
+def to_bytes_or_str(value: str, match_type_of: typing.AnyStr) -> typing.AnyStr:
+    return value if isinstance(match_type_of, str) else value.encode()
+
+
 def unquote(value: str) -> str:
     return value[1:-1] if value[0] == value[-1] == '"' else value
 
@@ -253,3 +278,64 @@ class ElapsedTimer:
         if self.end is None:
             return timedelta(seconds=perf_counter() - self.start)
         return timedelta(seconds=self.end - self.start)
+
+
+ASGI_PLACEHOLDER_FORMAT = {
+    "body": "<{length} bytes>",
+    "bytes": "<{length} bytes>",
+    "text": "<{length} chars>",
+}
+
+
+def asgi_message_with_placeholders(message: dict) -> dict:
+    """
+    Return an ASGI message, with any body-type content omitted and replaced
+    with a placeholder.
+    """
+    new_message = message.copy()
+
+    for attr in ASGI_PLACEHOLDER_FORMAT:
+        if attr in message:
+            content = message[attr]
+            placeholder = ASGI_PLACEHOLDER_FORMAT[attr].format(length=len(content))
+            new_message[attr] = placeholder
+
+    if "headers" in message:
+        new_message["headers"] = list(obfuscate_sensitive_headers(message["headers"]))
+
+    return new_message
+
+
+class MessageLoggerASGIMiddleware:
+    def __init__(self, app: typing.Callable, logger: logging.Logger) -> None:
+        self.app = app
+        self.logger = logger
+
+    async def __call__(
+        self, scope: dict, receive: typing.Callable, send: typing.Callable
+    ) -> None:
+        async def inner_receive() -> dict:
+            message = await receive()
+            logged_message = asgi_message_with_placeholders(message)
+            self.logger.debug(f"sent {kv_format(**logged_message)}")
+            return message
+
+        async def inner_send(message: dict) -> None:
+            logged_message = asgi_message_with_placeholders(message)
+            self.logger.debug(f"received {kv_format(**logged_message)}")
+            await send(message)
+
+        logged_scope = dict(scope)
+        if "headers" in scope:
+            logged_scope["headers"] = list(
+                obfuscate_sensitive_headers(scope["headers"])
+            )
+        self.logger.debug(f"started {kv_format(**logged_scope)}")
+
+        try:
+            await self.app(scope, inner_receive, inner_send)
+        except BaseException as exc:
+            self.logger.debug("raised_exception")
+            raise exc from None
+        else:
+            self.logger.debug("completed")
index 329fb52065581c83871fedd268877d0f5bd016af..959d782c72ad02ac11a593e2276c835514908a8d 100644 (file)
@@ -1,3 +1,5 @@
+import pytest
+
 import httpx
 
 
@@ -151,3 +153,13 @@ def test_multiple_headers():
 
     h = httpx.Headers([("vary", "a, b"), ("vary", "c")])
     h.getlist("Vary") == ["a", "b", "c"]
+
+
+@pytest.mark.parametrize("header", ["authorization", "proxy-authorization"])
+def test_sensitive_headers(header):
+    """
+    Some headers should be obfuscated because they contain sensitive data.
+    """
+    value = "s3kr3t"
+    h = httpx.Headers({header: value})
+    assert repr(h) == "Headers({'%s': '[secure]'})" % header
index d9b86f63cd2b917fb2450d5bea82f7d0ee8813e7..9e8169c7f9e22143645dd8dd1693f08c5e65af39 100644 (file)
@@ -12,6 +12,7 @@ from httpx.utils import (
     get_environment_proxies,
     get_netrc_login,
     guess_json_utf,
+    obfuscate_sensitive_headers,
     parse_header_links,
 )
 
@@ -192,3 +193,18 @@ def test_get_environment_proxies(environment, proxies):
     os.environ.update(environment)
 
     assert get_environment_proxies() == proxies
+
+
+@pytest.mark.parametrize(
+    "headers, output",
+    [
+        ([("content-type", "text/html")], [("content-type", "text/html")]),
+        ([("authorization", "s3kr3t")], [("authorization", "[secure]")]),
+        ([("proxy-authorization", "s3kr3t")], [("proxy-authorization", "[secure]")]),
+    ],
+)
+def test_obfuscate_sensitive_headers(headers, output):
+    bytes_headers = [(k.encode(), v.encode()) for k, v in headers]
+    bytes_output = [(k.encode(), v.encode()) for k, v in output]
+    assert list(obfuscate_sensitive_headers(headers)) == output
+    assert list(obfuscate_sensitive_headers(bytes_headers)) == bytes_output