]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Include host header directly (#109)
authorTom Christie <tom@tomchristie.com>
Thu, 27 Jun 2019 15:46:13 +0000 (16:46 +0100)
committerGitHub <noreply@github.com>
Thu, 27 Jun 2019 15:46:13 +0000 (16:46 +0100)
* Improve HTTP protocol detection

* Include host header when request is instantiated

* Add raise_app_exceptions

* Tweaks to ASGI dispatching

* Linting

* Don't quote multipart values

* Tweak decoder ordering in header

* Allow str data in request bodys

http3/client.py
http3/decoders.py
http3/dispatch/asgi.py
http3/dispatch/http11.py
http3/dispatch/http2.py
http3/dispatch/wsgi.py
http3/models.py
http3/multipart.py

index 8bf20de97c592724015f861df92a8c67d6eca52c..69b96b252a43514ce338ccefe07ec7a35e7fbd04 100644 (file)
@@ -54,6 +54,7 @@ class BaseClient:
         base_url: URLTypes = None,
         dispatch: typing.Union[AsyncDispatcher, Dispatcher] = None,
         app: typing.Callable = None,
+        raise_app_exceptions: bool = True,
         backend: ConcurrencyBackend = None,
     ):
         if backend is None:
@@ -63,9 +64,13 @@ class BaseClient:
             param_count = len(inspect.signature(app).parameters)
             assert param_count in (2, 3)
             if param_count == 2:
-                dispatch = WSGIDispatch(app=app)
+                dispatch = WSGIDispatch(
+                    app=app, raise_app_exceptions=raise_app_exceptions
+                )
             else:
-                dispatch = ASGIDispatch(app=app)
+                dispatch = ASGIDispatch(
+                    app=app, raise_app_exceptions=raise_app_exceptions
+                )
 
         if dispatch is None:
             async_dispatch = ConnectionPool(
@@ -558,7 +563,7 @@ class Client(BaseClient):
         If the request data is an bytes iterator then return an async bytes
         iterator onto the request data.
         """
-        if data is None or isinstance(data, (bytes, dict)):
+        if data is None or isinstance(data, (str, bytes, dict)):
             return data
 
         # Coerce an iterator into an async iterator, with each item in the
index b352de253fa712110f177d7e65fe3036f83e8f38..417f3177323f0845ecb114873f45413b37801edd 100644 (file)
@@ -134,8 +134,8 @@ class MultiDecoder(Decoder):
 
 SUPPORTED_DECODERS = {
     "identity": IdentityDecoder,
-    "deflate": DeflateDecoder,
     "gzip": GZipDecoder,
+    "deflate": DeflateDecoder,
     "br": BrotliDecoder,
 }
 
index da155a10f64013f7d0ea30bacc9ac282c35ae47b..384d6d47d0a75028ca99582cfda7a53b3a35f058 100644 (file)
@@ -35,10 +35,12 @@ class ASGIDispatch(AsyncDispatcher):
     def __init__(
         self,
         app: typing.Callable,
+        raise_app_exceptions: bool = True,
         root_path: str = "",
         client: typing.Tuple[str, int] = ("127.0.0.1", 123),
     ) -> None:
         self.app = app
+        self.raise_app_exceptions = raise_app_exceptions
         self.root_path = root_path
         self.client = client
 
@@ -53,11 +55,12 @@ class ASGIDispatch(AsyncDispatcher):
         scope = {
             "type": "http",
             "asgi": {"version": "3.0"},
+            "http_version": "1.1",
             "method": request.method,
             "headers": request.headers.raw,
             "scheme": request.url.scheme,
             "path": request.url.path,
-            "query": request.url.query.encode("ascii"),
+            "query_string": request.url.query.encode("ascii"),
             "server": request.url.host,
             "client": self.client,
             "root_path": self.root_path,
@@ -80,7 +83,7 @@ class ASGIDispatch(AsyncDispatcher):
             return {"type": "http.request", "body": body, "more_body": True}
 
         async def send(message: dict) -> None:
-            nonlocal status_code, headers, response_started, response_body
+            nonlocal status_code, headers, response_started, response_body, request
 
             if message["type"] == "http.response.start":
                 status_code = message["status"]
@@ -89,14 +92,13 @@ class ASGIDispatch(AsyncDispatcher):
             elif message["type"] == "http.response.body":
                 body = message.get("body", b"")
                 more_body = message.get("more_body", False)
-                if body:
+                if body and request.method != "HEAD":
                     await response_body.put(body)
                 if not more_body:
                     await response_body.done()
 
         async def run_app() -> None:
             nonlocal app, scope, receive, send, app_exc, response_body
-
             try:
                 await app(scope, receive, send)
             except Exception as exc:
@@ -117,17 +119,18 @@ class ASGIDispatch(AsyncDispatcher):
 
         await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
 
-        if app_exc is not None:
+        if app_exc is not None and self.raise_app_exceptions:
             raise app_exc
 
-        assert response_started.is_set, "application did not return a response."
+        assert response_started.is_set(), "application did not return a response."
         assert status_code is not None
         assert headers is not None
 
         async def on_close() -> None:
-            nonlocal app_task
+            nonlocal app_task, response_body
+            await response_body.drain()
             await app_task
-            if app_exc is not None:
+            if app_exc is not None and self.raise_app_exceptions:
                 raise app_exc
 
         return AsyncResponse(
@@ -163,6 +166,14 @@ class BodyIterator:
             assert isinstance(data, bytes)
             yield data
 
+    async def drain(self) -> None:
+        """
+        Drain any remaining body, in order to allow any blocked `put()` calls
+        to complete.
+        """
+        async for chunk in self.iterate():
+            pass  # pragma: no cover
+
     async def put(self, data: bytes) -> None:
         """
         Used by the server to add data to the response body.
index 99865117d1e4706b44859885951d6aa376942b04..b15e15884ddd2c77e59ea71a98941e904c7e7cad 100644 (file)
@@ -79,9 +79,6 @@ class HTTP11Connection:
         method = request.method.encode("ascii")
         target = request.url.full_path.encode("ascii")
         headers = request.headers.raw
-        if "Host" not in request.headers:
-            host = request.url.authority.encode("ascii")
-            headers = [(b"host", host)] + headers
         event = h11.Request(method=method, target=target, headers=headers)
         await self._send_event(event, timeout)
 
index 9bd35eaf2dab6a5a3cf8b7af89353aabda7c7f46..56c7728f63e3a41c7a9d4cdb25f4f02b532bdadc 100644 (file)
@@ -76,7 +76,7 @@ class HTTP2Connection:
             (b":authority", request.url.authority.encode("ascii")),
             (b":scheme", request.url.scheme.encode("ascii")),
             (b":path", request.url.full_path.encode("ascii")),
-        ] + request.headers.raw
+        ] + [(k, v) for k, v in request.headers.raw if k != b"host"]
         self.h2_state.send_headers(stream_id, headers)
         data_to_send = self.h2_state.data_to_send()
         await self.writer.write(data_to_send, timeout)
index 01c1aae13db8ed62cb9c448b8e7f41d005a31768..f0bf5311f6cc3055c7d3551d4e6d1dc4b4ddf3ac 100644 (file)
@@ -35,10 +35,12 @@ class WSGIDispatch(Dispatcher):
     def __init__(
         self,
         app: typing.Callable,
+        raise_app_exceptions: bool = True,
         script_name: str = "",
         remote_addr: str = "127.0.0.1",
     ) -> None:
         self.app = app
+        self.raise_app_exceptions = raise_app_exceptions
         self.script_name = script_name
         self.remote_addr = remote_addr
 
@@ -87,7 +89,7 @@ class WSGIDispatch(Dispatcher):
 
         assert seen_status is not None
         assert seen_response_headers is not None
-        if seen_exc_info:
+        if seen_exc_info and self.raise_app_exceptions:
             raise seen_exc_info[1]
 
         return Response(
index 02fc0ddb6a279f72dd8ef1ff0c28853a53adda88..1b5bb152e17fca6a4529f5c0a4d7109440cadceb 100644 (file)
@@ -51,9 +51,9 @@ AuthTypes = typing.Union[
     typing.Callable[["AsyncRequest"], "AsyncRequest"],
 ]
 
-AsyncRequestData = typing.Union[dict, bytes, typing.AsyncIterator[bytes]]
+AsyncRequestData = typing.Union[dict, str, bytes, typing.AsyncIterator[bytes]]
 
-RequestData = typing.Union[dict, bytes, typing.Iterator[bytes]]
+RequestData = typing.Union[dict, str, bytes, typing.Iterator[bytes]]
 
 RequestFiles = typing.Dict[
     str,
@@ -527,6 +527,7 @@ class BaseRequest:
 
         auto_headers = []  # type: typing.List[typing.Tuple[bytes, bytes]]
 
+        has_host = "host" in self.headers
         has_user_agent = "user-agent" in self.headers
         has_accept = "accept" in self.headers
         has_content_length = (
@@ -534,6 +535,8 @@ class BaseRequest:
         )
         has_accept_encoding = "accept-encoding" in self.headers
 
+        if not has_host:
+            auto_headers.append((b"host", self.url.authority.encode("ascii")))
         if not has_user_agent:
             auto_headers.append((b"user-agent", b"http3"))
         if not has_accept:
@@ -585,7 +588,8 @@ class AsyncRequest(BaseRequest):
             self.content = content
             if content_type:
                 self.headers["Content-Type"] = content_type
-        elif isinstance(data, bytes):
+        elif isinstance(data, (str, bytes)):
+            data = data.encode("utf-8") if isinstance(data, str) else data
             self.is_streaming = False
             self.content = data
         else:
@@ -634,7 +638,8 @@ class Request(BaseRequest):
             self.content = content
             if content_type:
                 self.headers["Content-Type"] = content_type
-        elif isinstance(data, bytes):
+        elif isinstance(data, (str, bytes)):
+            data = data.encode("utf-8") if isinstance(data, str) else data
             self.is_streaming = False
             self.content = data
         else:
index 07be1f13cc7e5e60cb394c891dd0fadfebd80a30..53fb68fd66b6b94e260676d4c13e4ae2ac91a19d 100644 (file)
@@ -3,7 +3,7 @@ import mimetypes
 import os
 import typing
 from io import BytesIO
-from urllib.parse import quote_plus
+from urllib.parse import quote
 
 
 class Field:
@@ -20,13 +20,13 @@ class DataField(Field):
         self.value = value
 
     def render_headers(self) -> bytes:
-        name = quote_plus(self.name, encoding="utf-8").encode("ascii")
+        name = quote(self.name, encoding="utf-8").encode("ascii")
         return b"".join(
             [b'Content-Disposition: form-data; name="', name, b'"\r\n' b"\r\n"]
         )
 
     def render_data(self) -> bytes:
-        return quote_plus(self.value, encoding="utf-8").encode("ascii")
+        return self.value.encode("utf-8")
 
 
 class FileField(Field):
@@ -49,8 +49,8 @@ class FileField(Field):
         return mimetypes.guess_type(self.filename)[0] or "application/octet-stream"
 
     def render_headers(self) -> bytes:
-        name = quote_plus(self.name, encoding="utf-8").encode("ascii")
-        filename = quote_plus(self.filename, encoding="utf-8").encode("ascii")
+        name = quote(self.name, encoding="utf-8").encode("ascii")
+        filename = quote(self.filename, encoding="utf-8").encode("ascii")
         content_type = self.content_type.encode("ascii")
         return b"".join(
             [