from ..concurrency.base import ConcurrencyBackend
from ..config import CertTypes, TimeoutTypes, VerifyTypes
from ..models import Request, Response
-from ..utils import MessageLoggerASGIMiddleware, get_logger
from .base import Dispatcher
-logger = get_logger(__name__)
-
class ASGIDispatch(Dispatcher):
"""
"client": self.client,
"root_path": self.root_path,
}
- app = MessageLoggerASGIMiddleware(self.app, logger=logger)
status_code = None
headers = None
body_parts = []
response_complete = True
try:
- await app(scope, receive, send)
+ await self.app(scope, receive, send)
except Exception:
if self.raise_app_exceptions or not response_complete:
raise
if self.end is None:
return timedelta(seconds=perf_counter() - self.start)
return timedelta(seconds=self.end - self.start)
-
-
-ASGI_PLACEHOLDER_FORMAT = {
- "body": "<{length} bytes>",
- "bytes": "<{length} bytes>",
- "text": "<{length} chars>",
-}
-
-
-def asgi_message_with_placeholders(message: dict) -> dict:
- """
- Return an ASGI message, with any body-type content omitted and replaced
- with a placeholder.
- """
- new_message = message.copy()
-
- for attr in ASGI_PLACEHOLDER_FORMAT:
- if attr in message:
- content = message[attr]
- placeholder = ASGI_PLACEHOLDER_FORMAT[attr].format(length=len(content))
- new_message[attr] = placeholder
-
- if "headers" in message:
- new_message["headers"] = list(obfuscate_sensitive_headers(message["headers"]))
-
- return new_message
-
-
-class MessageLoggerASGIMiddleware:
- def __init__(self, app: typing.Callable, logger: Logger) -> None:
- self.app = app
- self.logger = logger
-
- async def __call__(
- self, scope: dict, receive: typing.Callable, send: typing.Callable
- ) -> None:
- async def inner_receive() -> dict:
- message = await receive()
- logged_message = asgi_message_with_placeholders(message)
- self.logger.trace(f"sent {kv_format(**logged_message)}")
- return message
-
- async def inner_send(message: dict) -> None:
- logged_message = asgi_message_with_placeholders(message)
- self.logger.trace(f"received {kv_format(**logged_message)}")
- await send(message)
-
- logged_scope = dict(scope)
- if "headers" in scope:
- logged_scope["headers"] = list(
- obfuscate_sensitive_headers(scope["headers"])
- )
- self.logger.trace(f"started {kv_format(**logged_scope)}")
-
- try:
- await self.app(scope, inner_receive, inner_send)
- except BaseException as exc:
- self.logger.trace("raised_exception")
- raise exc from None
- else:
- self.logger.trace("completed")