]> git.ipfire.org Git - thirdparty/fastapi/fastapi.git/commitdiff
🐛 Fix `convert_underscores=False` for header Pydantic models (#13515)
authorSebastián Ramírez <tiangolo@gmail.com>
Sun, 23 Mar 2025 20:48:54 +0000 (20:48 +0000)
committerGitHub <noreply@github.com>
Sun, 23 Mar 2025 20:48:54 +0000 (21:48 +0100)
12 files changed:
docs/en/docs/tutorial/header-param-models.md
docs_src/header_param_models/tutorial003.py [new file with mode: 0644]
docs_src/header_param_models/tutorial003_an.py [new file with mode: 0644]
docs_src/header_param_models/tutorial003_an_py310.py [new file with mode: 0644]
docs_src/header_param_models/tutorial003_an_py39.py [new file with mode: 0644]
docs_src/header_param_models/tutorial003_py310.py [new file with mode: 0644]
docs_src/header_param_models/tutorial003_py39.py [new file with mode: 0644]
fastapi/dependencies/utils.py
fastapi/openapi/utils.py
tests/test_tutorial/test_header_param_models/test_tutorial001.py
tests/test_tutorial/test_header_param_models/test_tutorial002.py
tests/test_tutorial/test_header_param_models/test_tutorial003.py [new file with mode: 0644]

index 73950a668074e7a80e523efa4c4f412996462a7c..4cdf097057a15f6adb755dee5444a44e91b0b95d 100644 (file)
@@ -51,6 +51,22 @@ For example, if the client tries to send a `tool` header with a value of `plumbu
 }
 ```
 
+## Disable Convert Underscores
+
+The same way as with regular header parameters, when you have underscore characters in the parameter names, they are **automatically converted to hyphens**.
+
+For example, if you have a header parameter `save_data` in the code, the expected HTTP header will be `save-data`, and it will show up like that in the docs.
+
+If for some reason you need to disable this automatic conversion, you can do it as well for Pydantic models for header parameters.
+
+{* ../../docs_src/header_param_models/tutorial003_an_py310.py hl[19] *}
+
+/// warning
+
+Before setting `convert_underscores` to `False`, bear in mind that some HTTP proxies and servers disallow the usage of headers with underscores.
+
+///
+
 ## Summary
 
 You can use **Pydantic models** to declare **headers** in **FastAPI**. 😎
diff --git a/docs_src/header_param_models/tutorial003.py b/docs_src/header_param_models/tutorial003.py
new file mode 100644 (file)
index 0000000..dc2eb74
--- /dev/null
@@ -0,0 +1,19 @@
+from typing import List, Union
+
+from fastapi import FastAPI, Header
+from pydantic import BaseModel
+
+app = FastAPI()
+
+
+class CommonHeaders(BaseModel):
+    host: str
+    save_data: bool
+    if_modified_since: Union[str, None] = None
+    traceparent: Union[str, None] = None
+    x_tag: List[str] = []
+
+
+@app.get("/items/")
+async def read_items(headers: CommonHeaders = Header(convert_underscores=False)):
+    return headers
diff --git a/docs_src/header_param_models/tutorial003_an.py b/docs_src/header_param_models/tutorial003_an.py
new file mode 100644 (file)
index 0000000..e3edb11
--- /dev/null
@@ -0,0 +1,22 @@
+from typing import List, Union
+
+from fastapi import FastAPI, Header
+from pydantic import BaseModel
+from typing_extensions import Annotated
+
+app = FastAPI()
+
+
+class CommonHeaders(BaseModel):
+    host: str
+    save_data: bool
+    if_modified_since: Union[str, None] = None
+    traceparent: Union[str, None] = None
+    x_tag: List[str] = []
+
+
+@app.get("/items/")
+async def read_items(
+    headers: Annotated[CommonHeaders, Header(convert_underscores=False)],
+):
+    return headers
diff --git a/docs_src/header_param_models/tutorial003_an_py310.py b/docs_src/header_param_models/tutorial003_an_py310.py
new file mode 100644 (file)
index 0000000..07bfa83
--- /dev/null
@@ -0,0 +1,21 @@
+from typing import Annotated
+
+from fastapi import FastAPI, Header
+from pydantic import BaseModel
+
+app = FastAPI()
+
+
+class CommonHeaders(BaseModel):
+    host: str
+    save_data: bool
+    if_modified_since: str | None = None
+    traceparent: str | None = None
+    x_tag: list[str] = []
+
+
+@app.get("/items/")
+async def read_items(
+    headers: Annotated[CommonHeaders, Header(convert_underscores=False)],
+):
+    return headers
diff --git a/docs_src/header_param_models/tutorial003_an_py39.py b/docs_src/header_param_models/tutorial003_an_py39.py
new file mode 100644 (file)
index 0000000..8be6b01
--- /dev/null
@@ -0,0 +1,21 @@
+from typing import Annotated, Union
+
+from fastapi import FastAPI, Header
+from pydantic import BaseModel
+
+app = FastAPI()
+
+
+class CommonHeaders(BaseModel):
+    host: str
+    save_data: bool
+    if_modified_since: Union[str, None] = None
+    traceparent: Union[str, None] = None
+    x_tag: list[str] = []
+
+
+@app.get("/items/")
+async def read_items(
+    headers: Annotated[CommonHeaders, Header(convert_underscores=False)],
+):
+    return headers
diff --git a/docs_src/header_param_models/tutorial003_py310.py b/docs_src/header_param_models/tutorial003_py310.py
new file mode 100644 (file)
index 0000000..65e92a2
--- /dev/null
@@ -0,0 +1,17 @@
+from fastapi import FastAPI, Header
+from pydantic import BaseModel
+
+app = FastAPI()
+
+
+class CommonHeaders(BaseModel):
+    host: str
+    save_data: bool
+    if_modified_since: str | None = None
+    traceparent: str | None = None
+    x_tag: list[str] = []
+
+
+@app.get("/items/")
+async def read_items(headers: CommonHeaders = Header(convert_underscores=False)):
+    return headers
diff --git a/docs_src/header_param_models/tutorial003_py39.py b/docs_src/header_param_models/tutorial003_py39.py
new file mode 100644 (file)
index 0000000..848c341
--- /dev/null
@@ -0,0 +1,19 @@
+from typing import Union
+
+from fastapi import FastAPI, Header
+from pydantic import BaseModel
+
+app = FastAPI()
+
+
+class CommonHeaders(BaseModel):
+    host: str
+    save_data: bool
+    if_modified_since: Union[str, None] = None
+    traceparent: Union[str, None] = None
+    x_tag: list[str] = []
+
+
+@app.get("/items/")
+async def read_items(headers: CommonHeaders = Header(convert_underscores=False)):
+    return headers
index d205d17fac274132c7e8339f61ed1970f80ae49f..84dfa4d0306a6a885004cbc5cc6f5a835fd46955 100644 (file)
@@ -750,9 +750,15 @@ def request_params_to_args(
     first_field = fields[0]
     fields_to_extract = fields
     single_not_embedded_field = False
+    default_convert_underscores = True
     if len(fields) == 1 and lenient_issubclass(first_field.type_, BaseModel):
         fields_to_extract = get_cached_model_fields(first_field.type_)
         single_not_embedded_field = True
+        # If headers are in a Pydantic model, the way to disable convert_underscores
+        # would be with Header(convert_underscores=False) at the Pydantic model level
+        default_convert_underscores = getattr(
+            first_field.field_info, "convert_underscores", True
+        )
 
     params_to_process: Dict[str, Any] = {}
 
@@ -763,7 +769,9 @@ def request_params_to_args(
         if isinstance(received_params, Headers):
             # Handle fields extracted from a Pydantic Model for a header, each field
             # doesn't have a FieldInfo of type Header with the default convert_underscores=True
-            convert_underscores = getattr(field.field_info, "convert_underscores", True)
+            convert_underscores = getattr(
+                field.field_info, "convert_underscores", default_convert_underscores
+            )
             if convert_underscores:
                 alias = (
                     field.alias
index bd8f3c106acc42601c1f01bcc641ed57f80f74c5..808646cc27216c3423681041ee9923372a2a9d03 100644 (file)
@@ -32,6 +32,7 @@ from fastapi.utils import (
     generate_operation_id_for_path,
     is_body_allowed_for_status_code,
 )
+from pydantic import BaseModel
 from starlette.responses import JSONResponse
 from starlette.routing import BaseRoute
 from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY
@@ -113,6 +114,13 @@ def _get_openapi_operation_parameters(
         (ParamTypes.header, header_params),
         (ParamTypes.cookie, cookie_params),
     ]
+    default_convert_underscores = True
+    if len(flat_dependant.header_params) == 1:
+        first_field = flat_dependant.header_params[0]
+        if lenient_issubclass(first_field.type_, BaseModel):
+            default_convert_underscores = getattr(
+                first_field.field_info, "convert_underscores", True
+            )
     for param_type, param_group in parameter_groups:
         for param in param_group:
             field_info = param.field_info
@@ -126,8 +134,21 @@ def _get_openapi_operation_parameters(
                 field_mapping=field_mapping,
                 separate_input_output_schemas=separate_input_output_schemas,
             )
+            name = param.alias
+            convert_underscores = getattr(
+                param.field_info,
+                "convert_underscores",
+                default_convert_underscores,
+            )
+            if (
+                param_type == ParamTypes.header
+                and param.alias == param.name
+                and convert_underscores
+            ):
+                name = param.name.replace("_", "-")
+
             parameter = {
-                "name": param.alias,
+                "name": name,
                 "in": param_type.value,
                 "required": param.required,
                 "schema": param_schema,
index 06b2404cf0a3dec409ca910d633b9d752e48348f..bc876897b21f7190ad696d7e79f80a9cf78eacf9 100644 (file)
@@ -129,13 +129,13 @@ def test_openapi_schema(client: TestClient):
                                 "schema": {"type": "string", "title": "Host"},
                             },
                             {
-                                "name": "save_data",
+                                "name": "save-data",
                                 "in": "header",
                                 "required": True,
                                 "schema": {"type": "boolean", "title": "Save Data"},
                             },
                             {
-                                "name": "if_modified_since",
+                                "name": "if-modified-since",
                                 "in": "header",
                                 "required": False,
                                 "schema": IsDict(
@@ -171,7 +171,7 @@ def test_openapi_schema(client: TestClient):
                                 ),
                             },
                             {
-                                "name": "x_tag",
+                                "name": "x-tag",
                                 "in": "header",
                                 "required": False,
                                 "schema": {
index e07655a0c038daac2ea6ef9ff1164f032f7ca58f..0615521c43080e33a9ecdb68ea0d8b7c5b248a45 100644 (file)
@@ -140,13 +140,13 @@ def test_openapi_schema(client: TestClient):
                                 "schema": {"type": "string", "title": "Host"},
                             },
                             {
-                                "name": "save_data",
+                                "name": "save-data",
                                 "in": "header",
                                 "required": True,
                                 "schema": {"type": "boolean", "title": "Save Data"},
                             },
                             {
-                                "name": "if_modified_since",
+                                "name": "if-modified-since",
                                 "in": "header",
                                 "required": False,
                                 "schema": IsDict(
@@ -182,7 +182,7 @@ def test_openapi_schema(client: TestClient):
                                 ),
                             },
                             {
-                                "name": "x_tag",
+                                "name": "x-tag",
                                 "in": "header",
                                 "required": False,
                                 "schema": {
diff --git a/tests/test_tutorial/test_header_param_models/test_tutorial003.py b/tests/test_tutorial/test_header_param_models/test_tutorial003.py
new file mode 100644 (file)
index 0000000..60940e1
--- /dev/null
@@ -0,0 +1,285 @@
+import importlib
+
+import pytest
+from dirty_equals import IsDict
+from fastapi.testclient import TestClient
+from inline_snapshot import snapshot
+
+from tests.utils import needs_py39, needs_py310
+
+
+@pytest.fixture(
+    name="client",
+    params=[
+        "tutorial003",
+        pytest.param("tutorial003_py39", marks=needs_py39),
+        pytest.param("tutorial003_py310", marks=needs_py310),
+        "tutorial003_an",
+        pytest.param("tutorial003_an_py39", marks=needs_py39),
+        pytest.param("tutorial003_an_py310", marks=needs_py310),
+    ],
+)
+def get_client(request: pytest.FixtureRequest):
+    mod = importlib.import_module(f"docs_src.header_param_models.{request.param}")
+
+    client = TestClient(mod.app)
+    return client
+
+
+def test_header_param_model(client: TestClient):
+    response = client.get(
+        "/items/",
+        headers=[
+            ("save_data", "true"),
+            ("if_modified_since", "yesterday"),
+            ("traceparent", "123"),
+            ("x_tag", "one"),
+            ("x_tag", "two"),
+        ],
+    )
+    assert response.status_code == 200
+    assert response.json() == {
+        "host": "testserver",
+        "save_data": True,
+        "if_modified_since": "yesterday",
+        "traceparent": "123",
+        "x_tag": ["one", "two"],
+    }
+
+
+def test_header_param_model_no_underscore(client: TestClient):
+    response = client.get(
+        "/items/",
+        headers=[
+            ("save-data", "true"),
+            ("if-modified-since", "yesterday"),
+            ("traceparent", "123"),
+            ("x-tag", "one"),
+            ("x-tag", "two"),
+        ],
+    )
+    assert response.status_code == 422
+    assert response.json() == snapshot(
+        {
+            "detail": [
+                IsDict(
+                    {
+                        "type": "missing",
+                        "loc": ["header", "save_data"],
+                        "msg": "Field required",
+                        "input": {
+                            "host": "testserver",
+                            "traceparent": "123",
+                            "x_tag": [],
+                            "accept": "*/*",
+                            "accept-encoding": "gzip, deflate",
+                            "connection": "keep-alive",
+                            "user-agent": "testclient",
+                            "save-data": "true",
+                            "if-modified-since": "yesterday",
+                            "x-tag": "two",
+                        },
+                    }
+                )
+                | IsDict(
+                    # TODO: remove when deprecating Pydantic v1
+                    {
+                        "type": "value_error.missing",
+                        "loc": ["header", "save_data"],
+                        "msg": "field required",
+                    }
+                )
+            ]
+        }
+    )
+
+
+def test_header_param_model_defaults(client: TestClient):
+    response = client.get("/items/", headers=[("save_data", "true")])
+    assert response.status_code == 200
+    assert response.json() == {
+        "host": "testserver",
+        "save_data": True,
+        "if_modified_since": None,
+        "traceparent": None,
+        "x_tag": [],
+    }
+
+
+def test_header_param_model_invalid(client: TestClient):
+    response = client.get("/items/")
+    assert response.status_code == 422
+    assert response.json() == snapshot(
+        {
+            "detail": [
+                IsDict(
+                    {
+                        "type": "missing",
+                        "loc": ["header", "save_data"],
+                        "msg": "Field required",
+                        "input": {
+                            "x_tag": [],
+                            "host": "testserver",
+                            "accept": "*/*",
+                            "accept-encoding": "gzip, deflate",
+                            "connection": "keep-alive",
+                            "user-agent": "testclient",
+                        },
+                    }
+                )
+                | IsDict(
+                    # TODO: remove when deprecating Pydantic v1
+                    {
+                        "type": "value_error.missing",
+                        "loc": ["header", "save_data"],
+                        "msg": "field required",
+                    }
+                )
+            ]
+        }
+    )
+
+
+def test_header_param_model_extra(client: TestClient):
+    response = client.get(
+        "/items/", headers=[("save_data", "true"), ("tool", "plumbus")]
+    )
+    assert response.status_code == 200, response.text
+    assert response.json() == snapshot(
+        {
+            "host": "testserver",
+            "save_data": True,
+            "if_modified_since": None,
+            "traceparent": None,
+            "x_tag": [],
+        }
+    )
+
+
+def test_openapi_schema(client: TestClient):
+    response = client.get("/openapi.json")
+    assert response.status_code == 200, response.text
+    assert response.json() == snapshot(
+        {
+            "openapi": "3.1.0",
+            "info": {"title": "FastAPI", "version": "0.1.0"},
+            "paths": {
+                "/items/": {
+                    "get": {
+                        "summary": "Read Items",
+                        "operationId": "read_items_items__get",
+                        "parameters": [
+                            {
+                                "name": "host",
+                                "in": "header",
+                                "required": True,
+                                "schema": {"type": "string", "title": "Host"},
+                            },
+                            {
+                                "name": "save_data",
+                                "in": "header",
+                                "required": True,
+                                "schema": {"type": "boolean", "title": "Save Data"},
+                            },
+                            {
+                                "name": "if_modified_since",
+                                "in": "header",
+                                "required": False,
+                                "schema": IsDict(
+                                    {
+                                        "anyOf": [{"type": "string"}, {"type": "null"}],
+                                        "title": "If Modified Since",
+                                    }
+                                )
+                                | IsDict(
+                                    # TODO: remove when deprecating Pydantic v1
+                                    {
+                                        "type": "string",
+                                        "title": "If Modified Since",
+                                    }
+                                ),
+                            },
+                            {
+                                "name": "traceparent",
+                                "in": "header",
+                                "required": False,
+                                "schema": IsDict(
+                                    {
+                                        "anyOf": [{"type": "string"}, {"type": "null"}],
+                                        "title": "Traceparent",
+                                    }
+                                )
+                                | IsDict(
+                                    # TODO: remove when deprecating Pydantic v1
+                                    {
+                                        "type": "string",
+                                        "title": "Traceparent",
+                                    }
+                                ),
+                            },
+                            {
+                                "name": "x_tag",
+                                "in": "header",
+                                "required": False,
+                                "schema": {
+                                    "type": "array",
+                                    "items": {"type": "string"},
+                                    "default": [],
+                                    "title": "X Tag",
+                                },
+                            },
+                        ],
+                        "responses": {
+                            "200": {
+                                "description": "Successful Response",
+                                "content": {"application/json": {"schema": {}}},
+                            },
+                            "422": {
+                                "description": "Validation Error",
+                                "content": {
+                                    "application/json": {
+                                        "schema": {
+                                            "$ref": "#/components/schemas/HTTPValidationError"
+                                        }
+                                    }
+                                },
+                            },
+                        },
+                    }
+                }
+            },
+            "components": {
+                "schemas": {
+                    "HTTPValidationError": {
+                        "properties": {
+                            "detail": {
+                                "items": {
+                                    "$ref": "#/components/schemas/ValidationError"
+                                },
+                                "type": "array",
+                                "title": "Detail",
+                            }
+                        },
+                        "type": "object",
+                        "title": "HTTPValidationError",
+                    },
+                    "ValidationError": {
+                        "properties": {
+                            "loc": {
+                                "items": {
+                                    "anyOf": [{"type": "string"}, {"type": "integer"}]
+                                },
+                                "type": "array",
+                                "title": "Location",
+                            },
+                            "msg": {"type": "string", "title": "Message"},
+                            "type": {"type": "string", "title": "Error Type"},
+                        },
+                        "type": "object",
+                        "required": ["loc", "msg", "type"],
+                        "title": "ValidationError",
+                    },
+                }
+            },
+        }
+    )