]> git.ipfire.org Git - thirdparty/fastapi/fastapi.git/commitdiff
🐛 Fix removing body from status codes that do not support it (#5145)
authorSebastián Ramírez <tiangolo@gmail.com>
Thu, 14 Jul 2022 11:19:42 +0000 (13:19 +0200)
committerGitHub <noreply@github.com>
Thu, 14 Jul 2022 11:19:42 +0000 (13:19 +0200)
fastapi/openapi/constants.py
fastapi/openapi/utils.py
fastapi/routing.py
fastapi/utils.py
tests/test_response_code_no_body.py

index 3e69e55244af6121f3385384a89e085cb4f16286..1897ad750915ebc9043cbc2b9421d7ad8a3bcb0d 100644 (file)
@@ -1,3 +1,2 @@
 METHODS_WITH_BODY = {"GET", "HEAD", "POST", "PUT", "DELETE", "PATCH"}
-STATUS_CODES_WITH_NO_BODY = {100, 101, 102, 103, 204, 304}
 REF_PREFIX = "#/components/schemas/"
index 4eb727bd4ffc2d291c8ac2be11999e3fee52c27b..5d3d95c2442cd091a6ce815eab547b68691c4faf 100644 (file)
@@ -9,11 +9,7 @@ from fastapi.datastructures import DefaultPlaceholder
 from fastapi.dependencies.models import Dependant
 from fastapi.dependencies.utils import get_flat_dependant, get_flat_params
 from fastapi.encoders import jsonable_encoder
-from fastapi.openapi.constants import (
-    METHODS_WITH_BODY,
-    REF_PREFIX,
-    STATUS_CODES_WITH_NO_BODY,
-)
+from fastapi.openapi.constants import METHODS_WITH_BODY, REF_PREFIX
 from fastapi.openapi.models import OpenAPI
 from fastapi.params import Body, Param
 from fastapi.responses import Response
@@ -21,6 +17,7 @@ from fastapi.utils import (
     deep_dict_update,
     generate_operation_id_for_path,
     get_model_definitions,
+    is_body_allowed_for_status_code,
 )
 from pydantic import BaseModel
 from pydantic.fields import ModelField, Undefined
@@ -265,9 +262,8 @@ def get_openapi_path(
             operation.setdefault("responses", {}).setdefault(status_code, {})[
                 "description"
             ] = route.response_description
-            if (
-                route_response_media_type
-                and route.status_code not in STATUS_CODES_WITH_NO_BODY
+            if route_response_media_type and is_body_allowed_for_status_code(
+                route.status_code
             ):
                 response_schema = {"type": "string"}
                 if lenient_issubclass(current_response_class, JSONResponse):
index a6542c15a035e816a8122e163696ff9a4c77d324..6f1a8e9000dd6cf7b8f9765a853ad12e527d08e8 100644 (file)
@@ -29,13 +29,13 @@ from fastapi.dependencies.utils import (
 )
 from fastapi.encoders import DictIntStrAny, SetIntStr, jsonable_encoder
 from fastapi.exceptions import RequestValidationError, WebSocketRequestValidationError
-from fastapi.openapi.constants import STATUS_CODES_WITH_NO_BODY
 from fastapi.types import DecoratedCallable
 from fastapi.utils import (
     create_cloned_field,
     create_response_field,
     generate_unique_id,
     get_value_or_default,
+    is_body_allowed_for_status_code,
 )
 from pydantic import BaseModel
 from pydantic.error_wrappers import ErrorWrapper, ValidationError
@@ -232,7 +232,17 @@ def get_request_handler(
                 if raw_response.background is None:
                     raw_response.background = background_tasks
                 return raw_response
-            response_data = await serialize_response(
+            response_args: Dict[str, Any] = {"background": background_tasks}
+            # If status_code was set, use it, otherwise use the default from the
+            # response class, in the case of redirect it's 307
+            current_status_code = (
+                status_code if status_code else sub_response.status_code
+            )
+            if current_status_code is not None:
+                response_args["status_code"] = current_status_code
+            if sub_response.status_code:
+                response_args["status_code"] = sub_response.status_code
+            content = await serialize_response(
                 field=response_field,
                 response_content=raw_response,
                 include=response_model_include,
@@ -243,15 +253,10 @@ def get_request_handler(
                 exclude_none=response_model_exclude_none,
                 is_coroutine=is_coroutine,
             )
-            response_args: Dict[str, Any] = {"background": background_tasks}
-            # If status_code was set, use it, otherwise use the default from the
-            # response class, in the case of redirect it's 307
-            if status_code is not None:
-                response_args["status_code"] = status_code
-            response = actual_response_class(response_data, **response_args)
+            response = actual_response_class(content, **response_args)
+            if not is_body_allowed_for_status_code(status_code):
+                response.body = b""
             response.headers.raw.extend(sub_response.headers.raw)
-            if sub_response.status_code:
-                response.status_code = sub_response.status_code
             return response
 
     return app
@@ -377,8 +382,8 @@ class APIRoute(routing.Route):
             status_code = int(status_code)
         self.status_code = status_code
         if self.response_model:
-            assert (
-                status_code not in STATUS_CODES_WITH_NO_BODY
+            assert is_body_allowed_for_status_code(
+                status_code
             ), f"Status code {status_code} must not have a response body"
             response_name = "Response_" + self.unique_id
             self.response_field = create_response_field(
@@ -410,8 +415,8 @@ class APIRoute(routing.Route):
             assert isinstance(response, dict), "An additional response must be a dict"
             model = response.get("model")
             if model:
-                assert (
-                    additional_status_code not in STATUS_CODES_WITH_NO_BODY
+                assert is_body_allowed_for_status_code(
+                    additional_status_code
                 ), f"Status code {additional_status_code} must not have a response body"
                 response_name = f"Response_{additional_status_code}_{self.unique_id}"
                 response_field = create_response_field(name=response_name, type_=model)
index a7e135bcab598219cc32740796669ee1dee4f7ce..887d57c90258a80a1b7444318b8cb0c32fb1f192 100644 (file)
@@ -18,6 +18,13 @@ if TYPE_CHECKING:  # pragma: nocover
     from .routing import APIRoute
 
 
+def is_body_allowed_for_status_code(status_code: Union[int, str, None]) -> bool:
+    if status_code is None:
+        return True
+    current_status_code = int(status_code)
+    return not (current_status_code < 200 or current_status_code in {204, 304})
+
+
 def get_model_definitions(
     *,
     flat_models: Set[Union[Type[BaseModel], Type[Enum]]],
index 45e2fabc7ef79cff3e5498d1c7dcab2331e054a4..6d9b5c333401b65ae6421af9ab37fbc34dc3e5d5 100644 (file)
@@ -28,7 +28,7 @@ class JsonApiError(BaseModel):
     responses={500: {"description": "Error", "model": JsonApiError}},
 )
 async def a():
-    pass  # pragma: no cover
+    pass
 
 
 @app.get("/b", responses={204: {"description": "No Content"}})
@@ -106,3 +106,10 @@ def test_openapi_schema():
     response = client.get("/openapi.json")
     assert response.status_code == 200, response.text
     assert response.json() == openapi_schema
+
+
+def test_get_response():
+    response = client.get("/a")
+    assert response.status_code == 204, response.text
+    assert "content-length" not in response.headers
+    assert response.content == b""