base_url: URLTypes = None,
dispatch: typing.Union[AsyncDispatcher, Dispatcher] = None,
app: typing.Callable = None,
+ raise_app_exceptions: bool = True,
backend: ConcurrencyBackend = None,
):
if backend is None:
param_count = len(inspect.signature(app).parameters)
assert param_count in (2, 3)
if param_count == 2:
- dispatch = WSGIDispatch(app=app)
+ dispatch = WSGIDispatch(
+ app=app, raise_app_exceptions=raise_app_exceptions
+ )
else:
- dispatch = ASGIDispatch(app=app)
+ dispatch = ASGIDispatch(
+ app=app, raise_app_exceptions=raise_app_exceptions
+ )
if dispatch is None:
async_dispatch = ConnectionPool(
If the request data is an bytes iterator then return an async bytes
iterator onto the request data.
"""
- if data is None or isinstance(data, (bytes, dict)):
+ if data is None or isinstance(data, (str, bytes, dict)):
return data
# Coerce an iterator into an async iterator, with each item in the
SUPPORTED_DECODERS = {
"identity": IdentityDecoder,
- "deflate": DeflateDecoder,
"gzip": GZipDecoder,
+ "deflate": DeflateDecoder,
"br": BrotliDecoder,
}
def __init__(
self,
app: typing.Callable,
+ raise_app_exceptions: bool = True,
root_path: str = "",
client: typing.Tuple[str, int] = ("127.0.0.1", 123),
) -> None:
self.app = app
+ self.raise_app_exceptions = raise_app_exceptions
self.root_path = root_path
self.client = client
scope = {
"type": "http",
"asgi": {"version": "3.0"},
+ "http_version": "1.1",
"method": request.method,
"headers": request.headers.raw,
"scheme": request.url.scheme,
"path": request.url.path,
- "query": request.url.query.encode("ascii"),
+ "query_string": request.url.query.encode("ascii"),
"server": request.url.host,
"client": self.client,
"root_path": self.root_path,
return {"type": "http.request", "body": body, "more_body": True}
async def send(message: dict) -> None:
- nonlocal status_code, headers, response_started, response_body
+ nonlocal status_code, headers, response_started, response_body, request
if message["type"] == "http.response.start":
status_code = message["status"]
elif message["type"] == "http.response.body":
body = message.get("body", b"")
more_body = message.get("more_body", False)
- if body:
+ if body and request.method != "HEAD":
await response_body.put(body)
if not more_body:
await response_body.done()
async def run_app() -> None:
nonlocal app, scope, receive, send, app_exc, response_body
-
try:
await app(scope, receive, send)
except Exception as exc:
await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
- if app_exc is not None:
+ if app_exc is not None and self.raise_app_exceptions:
raise app_exc
- assert response_started.is_set, "application did not return a response."
+ assert response_started.is_set(), "application did not return a response."
assert status_code is not None
assert headers is not None
async def on_close() -> None:
- nonlocal app_task
+ nonlocal app_task, response_body
+ await response_body.drain()
await app_task
- if app_exc is not None:
+ if app_exc is not None and self.raise_app_exceptions:
raise app_exc
return AsyncResponse(
assert isinstance(data, bytes)
yield data
+ async def drain(self) -> None:
+ """
+ Drain any remaining body, in order to allow any blocked `put()` calls
+ to complete.
+ """
+ async for chunk in self.iterate():
+ pass # pragma: no cover
+
async def put(self, data: bytes) -> None:
"""
Used by the server to add data to the response body.
method = request.method.encode("ascii")
target = request.url.full_path.encode("ascii")
headers = request.headers.raw
- if "Host" not in request.headers:
- host = request.url.authority.encode("ascii")
- headers = [(b"host", host)] + headers
event = h11.Request(method=method, target=target, headers=headers)
await self._send_event(event, timeout)
(b":authority", request.url.authority.encode("ascii")),
(b":scheme", request.url.scheme.encode("ascii")),
(b":path", request.url.full_path.encode("ascii")),
- ] + request.headers.raw
+ ] + [(k, v) for k, v in request.headers.raw if k != b"host"]
self.h2_state.send_headers(stream_id, headers)
data_to_send = self.h2_state.data_to_send()
await self.writer.write(data_to_send, timeout)
def __init__(
self,
app: typing.Callable,
+ raise_app_exceptions: bool = True,
script_name: str = "",
remote_addr: str = "127.0.0.1",
) -> None:
self.app = app
+ self.raise_app_exceptions = raise_app_exceptions
self.script_name = script_name
self.remote_addr = remote_addr
assert seen_status is not None
assert seen_response_headers is not None
- if seen_exc_info:
+ if seen_exc_info and self.raise_app_exceptions:
raise seen_exc_info[1]
return Response(
typing.Callable[["AsyncRequest"], "AsyncRequest"],
]
-AsyncRequestData = typing.Union[dict, bytes, typing.AsyncIterator[bytes]]
+AsyncRequestData = typing.Union[dict, str, bytes, typing.AsyncIterator[bytes]]
-RequestData = typing.Union[dict, bytes, typing.Iterator[bytes]]
+RequestData = typing.Union[dict, str, bytes, typing.Iterator[bytes]]
RequestFiles = typing.Dict[
str,
auto_headers = [] # type: typing.List[typing.Tuple[bytes, bytes]]
+ has_host = "host" in self.headers
has_user_agent = "user-agent" in self.headers
has_accept = "accept" in self.headers
has_content_length = (
)
has_accept_encoding = "accept-encoding" in self.headers
+ if not has_host:
+ auto_headers.append((b"host", self.url.authority.encode("ascii")))
if not has_user_agent:
auto_headers.append((b"user-agent", b"http3"))
if not has_accept:
self.content = content
if content_type:
self.headers["Content-Type"] = content_type
- elif isinstance(data, bytes):
+ elif isinstance(data, (str, bytes)):
+ data = data.encode("utf-8") if isinstance(data, str) else data
self.is_streaming = False
self.content = data
else:
self.content = content
if content_type:
self.headers["Content-Type"] = content_type
- elif isinstance(data, bytes):
+ elif isinstance(data, (str, bytes)):
+ data = data.encode("utf-8") if isinstance(data, str) else data
self.is_streaming = False
self.content = data
else:
import os
import typing
from io import BytesIO
-from urllib.parse import quote_plus
+from urllib.parse import quote
class Field:
self.value = value
def render_headers(self) -> bytes:
- name = quote_plus(self.name, encoding="utf-8").encode("ascii")
+ name = quote(self.name, encoding="utf-8").encode("ascii")
return b"".join(
[b'Content-Disposition: form-data; name="', name, b'"\r\n' b"\r\n"]
)
def render_data(self) -> bytes:
- return quote_plus(self.value, encoding="utf-8").encode("ascii")
+ return self.value.encode("utf-8")
class FileField(Field):
return mimetypes.guess_type(self.filename)[0] or "application/octet-stream"
def render_headers(self) -> bytes:
- name = quote_plus(self.name, encoding="utf-8").encode("ascii")
- filename = quote_plus(self.filename, encoding="utf-8").encode("ascii")
+ name = quote(self.name, encoding="utf-8").encode("ascii")
+ filename = quote(self.filename, encoding="utf-8").encode("ascii")
content_type = self.content_type.encode("ascii")
return b"".join(
[