from ..concurrency.base import ConcurrencyBackend
from ..config import CertTypes, TimeoutTypes, VerifyTypes
from ..models import AsyncRequest, AsyncResponse
+from ..utils import MessageLoggerASGIMiddleware, get_logger
from .base import AsyncDispatcher
+logger = get_logger(__name__)
+
class ASGIDispatch(AsyncDispatcher):
"""
A custom dispatcher that handles sending requests directly to an ASGI app.
- The simplest way to use this functionality is to use the `app`argument.
+ The simplest way to use this functionality is to use the `app` argument.
This will automatically infer if 'app' is a WSGI or an ASGI application,
and will setup an appropriate dispatch class:
client=("1.2.3.4", 123)
)
client = httpx.Client(dispatch=dispatch)
+ ```
Arguments:
"client": self.client,
"root_path": self.root_path,
}
- app = self.app
+ app = MessageLoggerASGIMiddleware(self.app, logger=logger)
app_exc = None
status_code = None
headers = None
request_stream = request.stream()
async def receive() -> dict:
- nonlocal request_stream
-
try:
body = await request_stream.__anext__()
except StopAsyncIteration:
return {"type": "http.request", "body": body, "more_body": True}
async def send(message: dict) -> None:
- nonlocal status_code, headers, response_started_or_failed
- nonlocal response_body, request
+ nonlocal status_code, headers
if message["type"] == "http.response.start":
status_code = message["status"]
headers = message.get("headers", [])
response_started_or_failed.set()
+
elif message["type"] == "http.response.body":
body = message.get("body", b"")
more_body = message.get("more_body", False)
+
if body and request.method != "HEAD":
await response_body.put(body)
+
if not more_body:
await response_body.mark_as_done()
async def run_app() -> None:
- nonlocal app, scope, receive, send, app_exc, response_body
+ nonlocal app_exc
try:
await app(scope, receive, send)
except Exception as exc:
assert headers is not None
async def on_close() -> None:
- nonlocal response_body
await response_body.drain()
await background.close(app_exc)
if app_exc is not None and self.raise_app_exceptions:
is_known_encoding,
normalize_header_key,
normalize_header_value,
+ obfuscate_sensitive_headers,
parse_header_links,
str_query_param,
)
if self.encoding != "ascii":
encoding_str = f", encoding={self.encoding!r}"
- sensitive_headers = {"authorization", "proxy-authorization"}
- as_list = [
- (k, "[secure]" if k in sensitive_headers else v) for k, v in self.items()
- ]
-
+ as_list = list(obfuscate_sensitive_headers(self.items()))
as_dict = dict(as_list)
- if len(as_dict) == len(as_list):
+
+ no_duplicate_keys = len(as_dict) == len(as_list)
+ if no_duplicate_keys:
return f"{class_name}({as_dict!r}{encoding_str})"
return f"{class_name}({as_list!r}{encoding_str})"
return links
+SENSITIVE_HEADERS = {"authorization", "proxy-authorization"}
+
+
+def obfuscate_sensitive_headers(
+ items: typing.Iterable[typing.Tuple[typing.AnyStr, typing.AnyStr]]
+) -> typing.Iterator[typing.Tuple[typing.AnyStr, typing.AnyStr]]:
+ for k, v in items:
+ if to_str(k.lower()) in SENSITIVE_HEADERS:
+ v = to_bytes_or_str("[secure]", match_type_of=v)
+ yield k, v
+
+
_LOGGER_INITIALIZED = False
return logging.getLogger(name)
+def kv_format(**kwargs: typing.Any) -> str:
+ """Format arguments into a key=value line.
+
+ >>> formatkv(x=1, name="Bob")
+ "x=1 name='Bob'"
+ """
+ return " ".join(f"{key}={value!r}" for key, value in kwargs.items())
+
+
def get_environment_proxies() -> typing.Dict[str, str]:
"""Gets proxy information from the environment"""
return value if isinstance(value, str) else value.decode(encoding)
+def to_bytes_or_str(value: str, match_type_of: typing.AnyStr) -> typing.AnyStr:
+ return value if isinstance(match_type_of, str) else value.encode()
+
+
def unquote(value: str) -> str:
return value[1:-1] if value[0] == value[-1] == '"' else value
if self.end is None:
return timedelta(seconds=perf_counter() - self.start)
return timedelta(seconds=self.end - self.start)
+
+
+ASGI_PLACEHOLDER_FORMAT = {
+ "body": "<{length} bytes>",
+ "bytes": "<{length} bytes>",
+ "text": "<{length} chars>",
+}
+
+
+def asgi_message_with_placeholders(message: dict) -> dict:
+ """
+ Return an ASGI message, with any body-type content omitted and replaced
+ with a placeholder.
+ """
+ new_message = message.copy()
+
+ for attr in ASGI_PLACEHOLDER_FORMAT:
+ if attr in message:
+ content = message[attr]
+ placeholder = ASGI_PLACEHOLDER_FORMAT[attr].format(length=len(content))
+ new_message[attr] = placeholder
+
+ if "headers" in message:
+ new_message["headers"] = list(obfuscate_sensitive_headers(message["headers"]))
+
+ return new_message
+
+
+class MessageLoggerASGIMiddleware:
+ def __init__(self, app: typing.Callable, logger: logging.Logger) -> None:
+ self.app = app
+ self.logger = logger
+
+ async def __call__(
+ self, scope: dict, receive: typing.Callable, send: typing.Callable
+ ) -> None:
+ async def inner_receive() -> dict:
+ message = await receive()
+ logged_message = asgi_message_with_placeholders(message)
+ self.logger.debug(f"sent {kv_format(**logged_message)}")
+ return message
+
+ async def inner_send(message: dict) -> None:
+ logged_message = asgi_message_with_placeholders(message)
+ self.logger.debug(f"received {kv_format(**logged_message)}")
+ await send(message)
+
+ logged_scope = dict(scope)
+ if "headers" in scope:
+ logged_scope["headers"] = list(
+ obfuscate_sensitive_headers(scope["headers"])
+ )
+ self.logger.debug(f"started {kv_format(**logged_scope)}")
+
+ try:
+ await self.app(scope, inner_receive, inner_send)
+ except BaseException as exc:
+ self.logger.debug("raised_exception")
+ raise exc from None
+ else:
+ self.logger.debug("completed")
+import pytest
+
import httpx
h = httpx.Headers([("vary", "a, b"), ("vary", "c")])
h.getlist("Vary") == ["a", "b", "c"]
+
+
+@pytest.mark.parametrize("header", ["authorization", "proxy-authorization"])
+def test_sensitive_headers(header):
+ """
+ Some headers should be obfuscated because they contain sensitive data.
+ """
+ value = "s3kr3t"
+ h = httpx.Headers({header: value})
+ assert repr(h) == "Headers({'%s': '[secure]'})" % header
get_environment_proxies,
get_netrc_login,
guess_json_utf,
+ obfuscate_sensitive_headers,
parse_header_links,
)
os.environ.update(environment)
assert get_environment_proxies() == proxies
+
+
+@pytest.mark.parametrize(
+ "headers, output",
+ [
+ ([("content-type", "text/html")], [("content-type", "text/html")]),
+ ([("authorization", "s3kr3t")], [("authorization", "[secure]")]),
+ ([("proxy-authorization", "s3kr3t")], [("proxy-authorization", "[secure]")]),
+ ],
+)
+def test_obfuscate_sensitive_headers(headers, output):
+ bytes_headers = [(k.encode(), v.encode()) for k, v in headers]
+ bytes_output = [(k.encode(), v.encode()) for k, v in output]
+ assert list(obfuscate_sensitive_headers(headers)) == output
+ assert list(obfuscate_sensitive_headers(bytes_headers)) == bytes_output