From: Tom Christie Date: Thu, 27 Jun 2019 15:46:13 +0000 (+0100) Subject: Include host header directly (#109) X-Git-Tag: 0.6.5~1 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=5539b69ac26309bbb755bfb4c766d1e86c80ea1f;p=thirdparty%2Fhttpx.git Include host header directly (#109) * Improve HTTP protocol detection * Include host header when request is instantiated * Add raise_app_exceptions * Tweaks to ASGI dispatching * Linting * Don't quote multipart values * Tweak decoder ordering in header * Allow str data in request bodys --- diff --git a/http3/client.py b/http3/client.py index 8bf20de9..69b96b25 100644 --- a/http3/client.py +++ b/http3/client.py @@ -54,6 +54,7 @@ class BaseClient: base_url: URLTypes = None, dispatch: typing.Union[AsyncDispatcher, Dispatcher] = None, app: typing.Callable = None, + raise_app_exceptions: bool = True, backend: ConcurrencyBackend = None, ): if backend is None: @@ -63,9 +64,13 @@ class BaseClient: param_count = len(inspect.signature(app).parameters) assert param_count in (2, 3) if param_count == 2: - dispatch = WSGIDispatch(app=app) + dispatch = WSGIDispatch( + app=app, raise_app_exceptions=raise_app_exceptions + ) else: - dispatch = ASGIDispatch(app=app) + dispatch = ASGIDispatch( + app=app, raise_app_exceptions=raise_app_exceptions + ) if dispatch is None: async_dispatch = ConnectionPool( @@ -558,7 +563,7 @@ class Client(BaseClient): If the request data is an bytes iterator then return an async bytes iterator onto the request data. """ - if data is None or isinstance(data, (bytes, dict)): + if data is None or isinstance(data, (str, bytes, dict)): return data # Coerce an iterator into an async iterator, with each item in the diff --git a/http3/decoders.py b/http3/decoders.py index b352de25..417f3177 100644 --- a/http3/decoders.py +++ b/http3/decoders.py @@ -134,8 +134,8 @@ class MultiDecoder(Decoder): SUPPORTED_DECODERS = { "identity": IdentityDecoder, - "deflate": DeflateDecoder, "gzip": GZipDecoder, + "deflate": DeflateDecoder, "br": BrotliDecoder, } diff --git a/http3/dispatch/asgi.py b/http3/dispatch/asgi.py index da155a10..384d6d47 100644 --- a/http3/dispatch/asgi.py +++ b/http3/dispatch/asgi.py @@ -35,10 +35,12 @@ class ASGIDispatch(AsyncDispatcher): def __init__( self, app: typing.Callable, + raise_app_exceptions: bool = True, root_path: str = "", client: typing.Tuple[str, int] = ("127.0.0.1", 123), ) -> None: self.app = app + self.raise_app_exceptions = raise_app_exceptions self.root_path = root_path self.client = client @@ -53,11 +55,12 @@ class ASGIDispatch(AsyncDispatcher): scope = { "type": "http", "asgi": {"version": "3.0"}, + "http_version": "1.1", "method": request.method, "headers": request.headers.raw, "scheme": request.url.scheme, "path": request.url.path, - "query": request.url.query.encode("ascii"), + "query_string": request.url.query.encode("ascii"), "server": request.url.host, "client": self.client, "root_path": self.root_path, @@ -80,7 +83,7 @@ class ASGIDispatch(AsyncDispatcher): return {"type": "http.request", "body": body, "more_body": True} async def send(message: dict) -> None: - nonlocal status_code, headers, response_started, response_body + nonlocal status_code, headers, response_started, response_body, request if message["type"] == "http.response.start": status_code = message["status"] @@ -89,14 +92,13 @@ class ASGIDispatch(AsyncDispatcher): elif message["type"] == "http.response.body": body = message.get("body", b"") more_body = message.get("more_body", False) - if body: + if body and request.method != "HEAD": await response_body.put(body) if not more_body: await response_body.done() async def run_app() -> None: nonlocal app, scope, receive, send, app_exc, response_body - try: await app(scope, receive, send) except Exception as exc: @@ -117,17 +119,18 @@ class ASGIDispatch(AsyncDispatcher): await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) - if app_exc is not None: + if app_exc is not None and self.raise_app_exceptions: raise app_exc - assert response_started.is_set, "application did not return a response." + assert response_started.is_set(), "application did not return a response." assert status_code is not None assert headers is not None async def on_close() -> None: - nonlocal app_task + nonlocal app_task, response_body + await response_body.drain() await app_task - if app_exc is not None: + if app_exc is not None and self.raise_app_exceptions: raise app_exc return AsyncResponse( @@ -163,6 +166,14 @@ class BodyIterator: assert isinstance(data, bytes) yield data + async def drain(self) -> None: + """ + Drain any remaining body, in order to allow any blocked `put()` calls + to complete. + """ + async for chunk in self.iterate(): + pass # pragma: no cover + async def put(self, data: bytes) -> None: """ Used by the server to add data to the response body. diff --git a/http3/dispatch/http11.py b/http3/dispatch/http11.py index 99865117..b15e1588 100644 --- a/http3/dispatch/http11.py +++ b/http3/dispatch/http11.py @@ -79,9 +79,6 @@ class HTTP11Connection: method = request.method.encode("ascii") target = request.url.full_path.encode("ascii") headers = request.headers.raw - if "Host" not in request.headers: - host = request.url.authority.encode("ascii") - headers = [(b"host", host)] + headers event = h11.Request(method=method, target=target, headers=headers) await self._send_event(event, timeout) diff --git a/http3/dispatch/http2.py b/http3/dispatch/http2.py index 9bd35eaf..56c7728f 100644 --- a/http3/dispatch/http2.py +++ b/http3/dispatch/http2.py @@ -76,7 +76,7 @@ class HTTP2Connection: (b":authority", request.url.authority.encode("ascii")), (b":scheme", request.url.scheme.encode("ascii")), (b":path", request.url.full_path.encode("ascii")), - ] + request.headers.raw + ] + [(k, v) for k, v in request.headers.raw if k != b"host"] self.h2_state.send_headers(stream_id, headers) data_to_send = self.h2_state.data_to_send() await self.writer.write(data_to_send, timeout) diff --git a/http3/dispatch/wsgi.py b/http3/dispatch/wsgi.py index 01c1aae1..f0bf5311 100644 --- a/http3/dispatch/wsgi.py +++ b/http3/dispatch/wsgi.py @@ -35,10 +35,12 @@ class WSGIDispatch(Dispatcher): def __init__( self, app: typing.Callable, + raise_app_exceptions: bool = True, script_name: str = "", remote_addr: str = "127.0.0.1", ) -> None: self.app = app + self.raise_app_exceptions = raise_app_exceptions self.script_name = script_name self.remote_addr = remote_addr @@ -87,7 +89,7 @@ class WSGIDispatch(Dispatcher): assert seen_status is not None assert seen_response_headers is not None - if seen_exc_info: + if seen_exc_info and self.raise_app_exceptions: raise seen_exc_info[1] return Response( diff --git a/http3/models.py b/http3/models.py index 02fc0ddb..1b5bb152 100644 --- a/http3/models.py +++ b/http3/models.py @@ -51,9 +51,9 @@ AuthTypes = typing.Union[ typing.Callable[["AsyncRequest"], "AsyncRequest"], ] -AsyncRequestData = typing.Union[dict, bytes, typing.AsyncIterator[bytes]] +AsyncRequestData = typing.Union[dict, str, bytes, typing.AsyncIterator[bytes]] -RequestData = typing.Union[dict, bytes, typing.Iterator[bytes]] +RequestData = typing.Union[dict, str, bytes, typing.Iterator[bytes]] RequestFiles = typing.Dict[ str, @@ -527,6 +527,7 @@ class BaseRequest: auto_headers = [] # type: typing.List[typing.Tuple[bytes, bytes]] + has_host = "host" in self.headers has_user_agent = "user-agent" in self.headers has_accept = "accept" in self.headers has_content_length = ( @@ -534,6 +535,8 @@ class BaseRequest: ) has_accept_encoding = "accept-encoding" in self.headers + if not has_host: + auto_headers.append((b"host", self.url.authority.encode("ascii"))) if not has_user_agent: auto_headers.append((b"user-agent", b"http3")) if not has_accept: @@ -585,7 +588,8 @@ class AsyncRequest(BaseRequest): self.content = content if content_type: self.headers["Content-Type"] = content_type - elif isinstance(data, bytes): + elif isinstance(data, (str, bytes)): + data = data.encode("utf-8") if isinstance(data, str) else data self.is_streaming = False self.content = data else: @@ -634,7 +638,8 @@ class Request(BaseRequest): self.content = content if content_type: self.headers["Content-Type"] = content_type - elif isinstance(data, bytes): + elif isinstance(data, (str, bytes)): + data = data.encode("utf-8") if isinstance(data, str) else data self.is_streaming = False self.content = data else: diff --git a/http3/multipart.py b/http3/multipart.py index 07be1f13..53fb68fd 100644 --- a/http3/multipart.py +++ b/http3/multipart.py @@ -3,7 +3,7 @@ import mimetypes import os import typing from io import BytesIO -from urllib.parse import quote_plus +from urllib.parse import quote class Field: @@ -20,13 +20,13 @@ class DataField(Field): self.value = value def render_headers(self) -> bytes: - name = quote_plus(self.name, encoding="utf-8").encode("ascii") + name = quote(self.name, encoding="utf-8").encode("ascii") return b"".join( [b'Content-Disposition: form-data; name="', name, b'"\r\n' b"\r\n"] ) def render_data(self) -> bytes: - return quote_plus(self.value, encoding="utf-8").encode("ascii") + return self.value.encode("utf-8") class FileField(Field): @@ -49,8 +49,8 @@ class FileField(Field): return mimetypes.guess_type(self.filename)[0] or "application/octet-stream" def render_headers(self) -> bytes: - name = quote_plus(self.name, encoding="utf-8").encode("ascii") - filename = quote_plus(self.filename, encoding="utf-8").encode("ascii") + name = quote(self.name, encoding="utf-8").encode("ascii") + filename = quote(self.filename, encoding="utf-8").encode("ascii") content_type = self.content_type.encode("ascii") return b"".join( [