]> git.ipfire.org Git - thirdparty/fastapi/fastapi.git/commitdiff
Add HEAD method on GET routes
authorMarcelo Trylesinski <marcelotryle@gmail.com>
Thu, 20 Mar 2025 10:16:14 +0000 (10:16 +0000)
committerMarcelo Trylesinski <marcelotryle@gmail.com>
Thu, 20 Mar 2025 10:16:14 +0000 (10:16 +0000)
fastapi/openapi/utils.py
fastapi/routing.py
tests/test_extra_routes.py
tests/test_tutorial/test_first_steps/test_tutorial001.py

index bd8f3c106acc42601c1f01bcc641ed57f80f74c5..a7f4cdecc10eddce3dd297cb23bd59e96abc608e 100644 (file)
@@ -253,6 +253,8 @@ def get_openapi_path(
     route_response_media_type: Optional[str] = current_response_class.media_type
     if route.include_in_schema:
         for method in route.methods:
+            if method == "HEAD" and "GET" in route.methods:
+                continue
             operation = get_openapi_operation_metadata(
                 route=route, method=method, operation_ids=operation_ids
             )
index 457481e3258bffef81125f0ccc3caa461a9c2c9b..b454070921f50789f3b5c7901e75d5205e36b59d 100644 (file)
@@ -490,7 +490,9 @@ class APIRoute(routing.Route):
         self.name = get_name(endpoint) if name is None else name
         self.path_regex, self.path_format, self.param_convertors = compile_path(path)
         if methods is None:
-            methods = ["GET"]
+            methods = ["GET", "HEAD"]
+        elif "GET" in methods:
+            methods = set(methods) | {"HEAD"}
         self.methods: Set[str] = {method.upper() for method in methods}
         if isinstance(generate_unique_id_function, DefaultPlaceholder):
             current_generate_unique_id: Callable[[APIRoute], str] = (
@@ -1724,7 +1726,7 @@ class APIRouter(routing.Router):
             response_description=response_description,
             responses=responses,
             deprecated=deprecated,
-            methods=["GET"],
+            methods=["GET", "HEAD"],
             operation_id=operation_id,
             response_model_include=response_model_include,
             response_model_exclude=response_model_exclude,
index bd16fe9254cc341e0f0907676a0033694abc33e4..b5b5c65002bdef373a65ba7239056162da1d5655 100644 (file)
@@ -14,6 +14,11 @@ class Item(BaseModel):
     price: Optional[float] = None
 
 
+@app.head("/items/{item_id}")
+def head_item(item_id: str):
+    return JSONResponse(None, headers={"x-fastapi-item-id": item_id})
+
+
 @app.api_route("/items/{item_id}", methods=["GET"])
 def get_items(item_id: str):
     return {"item_id": item_id}
@@ -31,11 +36,6 @@ def delete_item(item_id: str, item: Item):
     return {"item_id": item_id, "item": item}
 
 
-@app.head("/items/{item_id}")
-def head_item(item_id: str):
-    return JSONResponse(None, headers={"x-fastapi-item-id": item_id})
-
-
 @app.options("/items/{item_id}")
 def options_item(item_id: str):
     return JSONResponse(None, headers={"x-fastapi-item-id": item_id})
index 6cc9fc22826dd5e163ea35417cb6d40866de877a..b9692bb2c4406210178b40a340126b0adf825d13 100644 (file)
@@ -13,11 +13,14 @@ client = TestClient(app)
         ("/nonexistent", 404, {"detail": "Not Found"}),
     ],
 )
-def test_get_path(path, expected_status, expected_response):
+def test_get_path(path: str, expected_status: int, expected_response: dict[str, str]):
     response = client.get(path)
     assert response.status_code == expected_status
     assert response.json() == expected_response
 
+    response = client.head(path)
+    assert response.status_code == expected_status
+
 
 def test_openapi_schema():
     response = client.get("/openapi.json")