]> git.ipfire.org Git - thirdparty/fastapi/fastapi.git/commitdiff
:sparkles: Add better support for request body access/manipulation with custom classe...
authordmontagu <35119617+dmontagu@users.noreply.github.com>
Sat, 5 Oct 2019 00:23:35 +0000 (17:23 -0700)
committerSebastián Ramírez <tiangolo@gmail.com>
Sat, 5 Oct 2019 00:23:34 +0000 (19:23 -0500)
docs/src/custom_request_and_route/tutorial001.py [new file with mode: 0644]
docs/src/custom_request_and_route/tutorial002.py [new file with mode: 0644]
docs/src/custom_request_and_route/tutorial003.py [new file with mode: 0644]
docs/tutorial/custom-request-and-route.md [new file with mode: 0644]
fastapi/routing.py
mkdocs.yml
tests/test_tutorial/test_custom_request_and_route/__init__.py [new file with mode: 0644]
tests/test_tutorial/test_custom_request_and_route/test_tutorial001.py [new file with mode: 0644]
tests/test_tutorial/test_custom_request_and_route/test_tutorial002.py [new file with mode: 0644]
tests/test_tutorial/test_custom_request_and_route/test_tutorial003.py [new file with mode: 0644]

diff --git a/docs/src/custom_request_and_route/tutorial001.py b/docs/src/custom_request_and_route/tutorial001.py
new file mode 100644 (file)
index 0000000..cd21582
--- /dev/null
@@ -0,0 +1,37 @@
+import gzip
+from typing import Callable, List
+
+from fastapi import Body, FastAPI
+from fastapi.routing import APIRoute
+from starlette.requests import Request
+from starlette.responses import Response
+
+
+class GzipRequest(Request):
+    async def body(self) -> bytes:
+        if not hasattr(self, "_body"):
+            body = await super().body()
+            if "gzip" in self.headers.getlist("Content-Encoding"):
+                body = gzip.decompress(body)
+            self._body = body
+        return self._body
+
+
+class GzipRoute(APIRoute):
+    def get_route_handler(self) -> Callable:
+        original_route_handler = super().get_route_handler()
+
+        async def custom_route_handler(request: Request) -> Response:
+            request = GzipRequest(request.scope, request.receive)
+            return await original_route_handler(request)
+
+        return custom_route_handler
+
+
+app = FastAPI()
+app.router.route_class = GzipRoute
+
+
+@app.post("/sum")
+async def sum_numbers(numbers: List[int] = Body(...)):
+    return {"sum": sum(numbers)}
diff --git a/docs/src/custom_request_and_route/tutorial002.py b/docs/src/custom_request_and_route/tutorial002.py
new file mode 100644 (file)
index 0000000..95cad99
--- /dev/null
@@ -0,0 +1,31 @@
+from typing import Callable, List
+
+from fastapi import Body, FastAPI, HTTPException
+from fastapi.exceptions import RequestValidationError
+from fastapi.routing import APIRoute
+from starlette.requests import Request
+from starlette.responses import Response
+
+
+class ValidationErrorLoggingRoute(APIRoute):
+    def get_route_handler(self) -> Callable:
+        original_route_handler = super().get_route_handler()
+
+        async def custom_route_handler(request: Request) -> Response:
+            try:
+                return await original_route_handler(request)
+            except RequestValidationError as exc:
+                body = await request.body()
+                detail = {"errors": exc.errors(), "body": body.decode()}
+                raise HTTPException(status_code=422, detail=detail)
+
+        return custom_route_handler
+
+
+app = FastAPI()
+app.router.route_class = ValidationErrorLoggingRoute
+
+
+@app.post("/")
+async def sum_numbers(numbers: List[int] = Body(...)):
+    return sum(numbers)
diff --git a/docs/src/custom_request_and_route/tutorial003.py b/docs/src/custom_request_and_route/tutorial003.py
new file mode 100644 (file)
index 0000000..4497736
--- /dev/null
@@ -0,0 +1,41 @@
+import time
+from typing import Callable
+
+from fastapi import APIRouter, FastAPI
+from fastapi.routing import APIRoute
+from starlette.requests import Request
+from starlette.responses import Response
+
+
+class TimedRoute(APIRoute):
+    def get_route_handler(self) -> Callable:
+        original_route_handler = super().get_route_handler()
+
+        async def custom_route_handler(request: Request) -> Response:
+            before = time.time()
+            response: Response = await original_route_handler(request)
+            duration = time.time() - before
+            response.headers["X-Response-Time"] = str(duration)
+            print(f"route duration: {duration}")
+            print(f"route response: {response}")
+            print(f"route response headers: {response.headers}")
+            return response
+
+        return custom_route_handler
+
+
+app = FastAPI()
+router = APIRouter(route_class=TimedRoute)
+
+
+@app.get("/")
+async def not_timed():
+    return {"message": "Not timed"}
+
+
+@router.get("/timed")
+async def timed():
+    return {"message": "It's the time of my life"}
+
+
+app.include_router(router)
diff --git a/docs/tutorial/custom-request-and-route.md b/docs/tutorial/custom-request-and-route.md
new file mode 100644 (file)
index 0000000..49cca99
--- /dev/null
@@ -0,0 +1,100 @@
+In some cases, you may want to override the logic used by the `Request` and `APIRoute` classes.
+
+In particular, this may be a good alternative to logic in a middleware.
+
+For example, if you want to read or manipulate the request body before it is processed by your application.
+
+!!! danger
+    This is an "advanced" feature.
+
+    If you are just starting with **FastAPI** you might want to skip this section.
+
+## Use cases
+
+Some use cases include:
+
+* Converting non-JSON request bodies to JSON (e.g. [`msgpack`](https://msgpack.org/index.html)).
+* Decompressing gzip-compressed request bodies.
+* Automatically logging all request bodies.
+* Accessing the request body in an exception handler.
+
+## Handling custom request body encodings
+
+Let's see how to make use of a custom `Request` subclass to decompress gzip requests.
+
+And an `APIRoute` subclass to use that custom request class.
+
+### Create a custom `GzipRequest` class
+
+First, we create a `GzipRequest` class, which will overwrite the `Request.body()` method to decompress the body in the presence of an appropriate header.
+
+If there's no `gzip` in the header, it will not try to decompress the body.
+
+That way, the same route class can handle gzip compressed or uncompressed requests.
+
+```Python hl_lines="10 11 12 13 14 15 16 17"
+{!./src/custom_request_and_route/tutorial001.py!}
+```
+
+### Create a custom `GzipRoute` class
+
+Next, we create a custom subclass of `fastapi.routing.APIRoute` that will make use of the `GzipRequest`.
+
+This time, it will overwrite the method `APIRoute.get_route_handler()`.
+
+This method returns a function. And that function is what will receive a request and return a response.
+
+Here we use it to create a `GzipRequest` from the original request.
+
+```Python hl_lines="20 21 22 23 24 25 26 27 28"
+{!./src/custom_request_and_route/tutorial001.py!}
+```
+
+!!! note "Technical Details"
+    A `Request` has a `request.scope` attribute, that's just a Python `dict` containing the metadata related to the request.
+
+    A `Request` also has a `request.receive`, that's a function to "receive" the body of the request.
+
+    The `scope` `dict` and `receive` function are both part of the ASGI specification.
+
+    And those two things, `scope` and `receive`, are what is needed to create a new `Request` instance.
+
+    To learn more about the `Request` check <a href="https://www.starlette.io/requests/" target="_blank">Starlette's docs about Requests</a>.
+
+The only thing the function returned by `GzipRequest.get_route_handler` does differently is convert the `Request` to a `GzipRequest`.
+
+Doing this, our `GzipRequest` will take care of decompressing the data (if necessary) before passing it to our *path operations*.
+
+After that, all of the processing logic is the same.
+
+But because of our changes in `GzipRequest.body`, the request body will be automatically decompressed when it is loaded by **FastAPI** when needed.
+
+## Accessing the request body in an exception handler
+
+We can also use this same approach to access the request body in an exception handler.
+
+All we need to do is handle the request inside a `try`/`except` block:
+
+```Python hl_lines="15 17"
+{!./src/custom_request_and_route/tutorial002.py!}
+```
+
+If an exception occurs, the`Request` instance will still be in scope, so we can read and make use of the request body when handling the error:
+
+```Python hl_lines="18 19 20"
+{!./src/custom_request_and_route/tutorial002.py!}
+```
+
+## Custom `APIRoute` class in a router
+
+You can also set the `route_class` parameter of an `APIRouter`:
+
+```Python hl_lines="25"
+{!./src/custom_request_and_route/tutorial003.py!}
+```
+
+In this example, the *path operations* under the `router` will use the custom `TimedRoute` class, and will have an extra `X-Response-Time` header in the response with the time it took to generate the response:
+
+```Python hl_lines="15 16 17 18 19"
+{!./src/custom_request_and_route/tutorial003.py!}
+```
index b0902310c7c9dadabbcde3fb0e79ee4209e78a2d..2a4e0bc8d4f2e665acaf0c6428b68d21e3e4becb 100644 (file)
@@ -65,7 +65,7 @@ def serialize_response(
         return jsonable_encoder(response)
 
 
-def get_app(
+def get_request_handler(
     dependant: Dependant,
     body_field: Field = None,
     status_code: int = 200,
@@ -294,19 +294,20 @@ class APIRoute(routing.Route):
             )
         self.body_field = get_body_field(dependant=self.dependant, name=self.unique_id)
         self.dependency_overrides_provider = dependency_overrides_provider
-        self.app = request_response(
-            get_app(
-                dependant=self.dependant,
-                body_field=self.body_field,
-                status_code=self.status_code,
-                response_class=self.response_class or JSONResponse,
-                response_field=self.secure_cloned_response_field,
-                response_model_include=self.response_model_include,
-                response_model_exclude=self.response_model_exclude,
-                response_model_by_alias=self.response_model_by_alias,
-                response_model_skip_defaults=self.response_model_skip_defaults,
-                dependency_overrides_provider=self.dependency_overrides_provider,
-            )
+        self.app = request_response(self.get_route_handler())
+
+    def get_route_handler(self) -> Callable:
+        return get_request_handler(
+            dependant=self.dependant,
+            body_field=self.body_field,
+            status_code=self.status_code,
+            response_class=self.response_class or JSONResponse,
+            response_field=self.secure_cloned_response_field,
+            response_model_include=self.response_model_include,
+            response_model_exclude=self.response_model_exclude,
+            response_model_by_alias=self.response_model_by_alias,
+            response_model_skip_defaults=self.response_model_skip_defaults,
+            dependency_overrides_provider=self.dependency_overrides_provider,
         )
 
 
index a61deedb073db7fb1abf8d4c08b243331cc57c18..b4c2ec15219493315b3d6f35dddbd400a12380f2 100644 (file)
@@ -81,6 +81,7 @@ nav:
         - GraphQL: 'tutorial/graphql.md'
         - WebSockets: 'tutorial/websockets.md'
         - 'Events: startup - shutdown': 'tutorial/events.md'
+        - Custom Request and APIRoute class: 'tutorial/custom-request-and-route.md'
         - Testing: 'tutorial/testing.md'
         - Testing Dependencies with Overrides: 'tutorial/testing-dependencies.md'
         - Debugging: 'tutorial/debugging.md'
diff --git a/tests/test_tutorial/test_custom_request_and_route/__init__.py b/tests/test_tutorial/test_custom_request_and_route/__init__.py
new file mode 100644 (file)
index 0000000..e69de29
diff --git a/tests/test_tutorial/test_custom_request_and_route/test_tutorial001.py b/tests/test_tutorial/test_custom_request_and_route/test_tutorial001.py
new file mode 100644 (file)
index 0000000..2b4b474
--- /dev/null
@@ -0,0 +1,34 @@
+import gzip
+import json
+
+import pytest
+from starlette.requests import Request
+from starlette.testclient import TestClient
+
+from custom_request_and_route.tutorial001 import app
+
+
+@app.get("/check-class")
+async def check_gzip_request(request: Request):
+    return {"request_class": type(request).__name__}
+
+
+client = TestClient(app)
+
+
+@pytest.mark.parametrize("compress", [True, False])
+def test_gzip_request(compress):
+    n = 1000
+    headers = {}
+    body = [1] * n
+    data = json.dumps(body).encode()
+    if compress:
+        data = gzip.compress(data)
+        headers["Content-Encoding"] = "gzip"
+    response = client.post("/sum", data=data, headers=headers)
+    assert response.json() == {"sum": n}
+
+
+def test_request_class():
+    response = client.get("/check-class")
+    assert response.json() == {"request_class": "GzipRequest"}
diff --git a/tests/test_tutorial/test_custom_request_and_route/test_tutorial002.py b/tests/test_tutorial/test_custom_request_and_route/test_tutorial002.py
new file mode 100644 (file)
index 0000000..a50760b
--- /dev/null
@@ -0,0 +1,27 @@
+from starlette.testclient import TestClient
+
+from custom_request_and_route.tutorial002 import app
+
+client = TestClient(app)
+
+
+def test_endpoint_works():
+    response = client.post("/", json=[1, 2, 3])
+    assert response.json() == 6
+
+
+def test_exception_handler_body_access():
+    response = client.post("/", json={"numbers": [1, 2, 3]})
+
+    assert response.json() == {
+        "detail": {
+            "body": '{"numbers": [1, 2, 3]}',
+            "errors": [
+                {
+                    "loc": ["body", "numbers"],
+                    "msg": "value is not a valid list",
+                    "type": "type_error.list",
+                }
+            ],
+        }
+    }
diff --git a/tests/test_tutorial/test_custom_request_and_route/test_tutorial003.py b/tests/test_tutorial/test_custom_request_and_route/test_tutorial003.py
new file mode 100644 (file)
index 0000000..bc4ccac
--- /dev/null
@@ -0,0 +1,18 @@
+from starlette.testclient import TestClient
+
+from custom_request_and_route.tutorial003 import app
+
+client = TestClient(app)
+
+
+def test_get():
+    response = client.get("/")
+    assert response.json() == {"message": "Not timed"}
+    assert "X-Response-Time" not in response.headers
+
+
+def test_get_timed():
+    response = client.get("/timed")
+    assert response.json() == {"message": "It's the time of my life"}
+    assert "X-Response-Time" in response.headers
+    assert float(response.headers["X-Response-Time"]) > 0