]> git.ipfire.org Git - thirdparty/fastapi/fastapi.git/commitdiff
🐛 Make sure a parameter defined as required is kept required in OpenAPI even if defin...
authorCharlie DiGiovanna <cd17822@gmail.com>
Sat, 3 Sep 2022 17:12:41 +0000 (13:12 -0400)
committerGitHub <noreply@github.com>
Sat, 3 Sep 2022 17:12:41 +0000 (17:12 +0000)
Co-authored-by: Sebastián Ramírez <tiangolo@gmail.com>
fastapi/openapi/utils.py
tests/test_enforce_once_required_parameter.py [new file with mode: 0644]

index 4d5741f30de1c45f90244ee70aa307cf274fac92..86e15b46d30a35fddbc92f2dd993497c835f5e33 100644 (file)
@@ -222,11 +222,18 @@ def get_openapi_path(
             )
             parameters.extend(operation_parameters)
             if parameters:
-                operation["parameters"] = list(
-                    {
-                        (param["in"], param["name"]): param for param in parameters
-                    }.values()
-                )
+                all_parameters = {
+                    (param["in"], param["name"]): param for param in parameters
+                }
+                required_parameters = {
+                    (param["in"], param["name"]): param
+                    for param in parameters
+                    if param.get("required")
+                }
+                # Make sure required definitions of the same parameter take precedence
+                # over non-required definitions
+                all_parameters.update(required_parameters)
+                operation["parameters"] = list(all_parameters.values())
             if method in METHODS_WITH_BODY:
                 request_body_oai = get_openapi_operation_request_body(
                     body_field=route.body_field, model_name_map=model_name_map
diff --git a/tests/test_enforce_once_required_parameter.py b/tests/test_enforce_once_required_parameter.py
new file mode 100644 (file)
index 0000000..ba8c735
--- /dev/null
@@ -0,0 +1,111 @@
+from typing import Optional
+
+from fastapi import Depends, FastAPI, Query, status
+from fastapi.testclient import TestClient
+
+app = FastAPI()
+
+
+def _get_client_key(client_id: str = Query(...)) -> str:
+    return f"{client_id}_key"
+
+
+def _get_client_tag(client_id: Optional[str] = Query(None)) -> Optional[str]:
+    if client_id is None:
+        return None
+    return f"{client_id}_tag"
+
+
+@app.get("/foo")
+def foo_handler(
+    client_key: str = Depends(_get_client_key),
+    client_tag: Optional[str] = Depends(_get_client_tag),
+):
+    return {"client_id": client_key, "client_tag": client_tag}
+
+
+client = TestClient(app)
+
+expected_schema = {
+    "components": {
+        "schemas": {
+            "HTTPValidationError": {
+                "properties": {
+                    "detail": {
+                        "items": {"$ref": "#/components/schemas/ValidationError"},
+                        "title": "Detail",
+                        "type": "array",
+                    }
+                },
+                "title": "HTTPValidationError",
+                "type": "object",
+            },
+            "ValidationError": {
+                "properties": {
+                    "loc": {
+                        "items": {"anyOf": [{"type": "string"}, {"type": "integer"}]},
+                        "title": "Location",
+                        "type": "array",
+                    },
+                    "msg": {"title": "Message", "type": "string"},
+                    "type": {"title": "Error " "Type", "type": "string"},
+                },
+                "required": ["loc", "msg", "type"],
+                "title": "ValidationError",
+                "type": "object",
+            },
+        }
+    },
+    "info": {"title": "FastAPI", "version": "0.1.0"},
+    "openapi": "3.0.2",
+    "paths": {
+        "/foo": {
+            "get": {
+                "operationId": "foo_handler_foo_get",
+                "parameters": [
+                    {
+                        "in": "query",
+                        "name": "client_id",
+                        "required": True,
+                        "schema": {"title": "Client Id", "type": "string"},
+                    },
+                ],
+                "responses": {
+                    "200": {
+                        "content": {"application/json": {"schema": {}}},
+                        "description": "Successful " "Response",
+                    },
+                    "422": {
+                        "content": {
+                            "application/json": {
+                                "schema": {
+                                    "$ref": "#/components/schemas/HTTPValidationError"
+                                }
+                            }
+                        },
+                        "description": "Validation " "Error",
+                    },
+                },
+                "summary": "Foo Handler",
+            }
+        }
+    },
+}
+
+
+def test_schema():
+    response = client.get("/openapi.json")
+    assert response.status_code == status.HTTP_200_OK
+    actual_schema = response.json()
+    assert actual_schema == expected_schema
+
+
+def test_get_invalid():
+    response = client.get("/foo", params={"client_id": None})
+    assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
+
+
+def test_get_valid():
+    response = client.get("/foo", params={"client_id": "bar"})
+    assert response.status_code == 200
+    assert response.json() == {"client_id": "bar_key", "client_tag": "bar_tag"}