]> git.ipfire.org Git - thirdparty/fastapi/fastapi.git/commitdiff
✨ Add support for multiple Annotated annotations, e.g. `Annotated[str, Field(), Query...
authorSebastián Ramírez <tiangolo@gmail.com>
Tue, 12 Dec 2023 00:22:47 +0000 (00:22 +0000)
committerGitHub <noreply@github.com>
Tue, 12 Dec 2023 00:22:47 +0000 (00:22 +0000)
.github/workflows/test.yml
fastapi/_compat.py
fastapi/dependencies/utils.py
tests/test_ambiguous_params.py
tests/test_annotated.py

index 59754525d742419e987eb3006f68ac1f2e372471..7ebb80efdfc99944d3ef7b3428bf43017cde33ce 100644 (file)
@@ -29,7 +29,7 @@ jobs:
         id: cache
         with:
           path: ${{ env.pythonLocation }}
-          key: ${{ runner.os }}-python-${{ env.pythonLocation }}-pydantic-v2-${{ hashFiles('pyproject.toml', 'requirements-tests.txt', 'requirements-docs-tests.txt') }}-test-v06
+          key: ${{ runner.os }}-python-${{ env.pythonLocation }}-pydantic-v2-${{ hashFiles('pyproject.toml', 'requirements-tests.txt', 'requirements-docs-tests.txt') }}-test-v07
       - name: Install Dependencies
         if: steps.cache.outputs.cache-hit != 'true'
         run: pip install -r requirements-tests.txt
@@ -62,7 +62,7 @@ jobs:
         id: cache
         with:
           path: ${{ env.pythonLocation }}
-          key: ${{ runner.os }}-python-${{ env.pythonLocation }}-${{ matrix.pydantic-version }}-${{ hashFiles('pyproject.toml', 'requirements-tests.txt', 'requirements-docs-tests.txt') }}-test-v06
+          key: ${{ runner.os }}-python-${{ env.pythonLocation }}-${{ matrix.pydantic-version }}-${{ hashFiles('pyproject.toml', 'requirements-tests.txt', 'requirements-docs-tests.txt') }}-test-v07
       - name: Install Dependencies
         if: steps.cache.outputs.cache-hit != 'true'
         run: pip install -r requirements-tests.txt
index fc605d0ec68e383a0dbb51a5b7f26a5f9a15ac1a..35d4a8723113935a37a7d5267b8668662138fe10 100644 (file)
@@ -249,7 +249,12 @@ if PYDANTIC_V2:
         return is_bytes_sequence_annotation(field.type_)
 
     def copy_field_info(*, field_info: FieldInfo, annotation: Any) -> FieldInfo:
-        return type(field_info).from_annotation(annotation)
+        cls = type(field_info)
+        merged_field_info = cls.from_annotation(annotation)
+        new_field_info = copy(field_info)
+        new_field_info.metadata = merged_field_info.metadata
+        new_field_info.annotation = merged_field_info.annotation
+        return new_field_info
 
     def serialize_sequence_value(*, field: ModelField, value: Any) -> Sequence[Any]:
         origin_type = (
index 96e07a45c79db12c49bf41bb27fb9d1cd3e09711..4e88410a5ec1fb525283f62064ed461364526a63 100644 (file)
@@ -325,10 +325,11 @@ def analyze_param(
     field_info = None
     depends = None
     type_annotation: Any = Any
-    if (
-        annotation is not inspect.Signature.empty
-        and get_origin(annotation) is Annotated
-    ):
+    use_annotation: Any = Any
+    if annotation is not inspect.Signature.empty:
+        use_annotation = annotation
+        type_annotation = annotation
+    if get_origin(use_annotation) is Annotated:
         annotated_args = get_args(annotation)
         type_annotation = annotated_args[0]
         fastapi_annotations = [
@@ -336,14 +337,21 @@ def analyze_param(
             for arg in annotated_args[1:]
             if isinstance(arg, (FieldInfo, params.Depends))
         ]
-        assert (
-            len(fastapi_annotations) <= 1
-        ), f"Cannot specify multiple `Annotated` FastAPI arguments for {param_name!r}"
-        fastapi_annotation = next(iter(fastapi_annotations), None)
+        fastapi_specific_annotations = [
+            arg
+            for arg in fastapi_annotations
+            if isinstance(arg, (params.Param, params.Body, params.Depends))
+        ]
+        if fastapi_specific_annotations:
+            fastapi_annotation: Union[
+                FieldInfo, params.Depends, None
+            ] = fastapi_specific_annotations[-1]
+        else:
+            fastapi_annotation = None
         if isinstance(fastapi_annotation, FieldInfo):
             # Copy `field_info` because we mutate `field_info.default` below.
             field_info = copy_field_info(
-                field_info=fastapi_annotation, annotation=annotation
+                field_info=fastapi_annotation, annotation=use_annotation
             )
             assert field_info.default is Undefined or field_info.default is Required, (
                 f"`{field_info.__class__.__name__}` default value cannot be set in"
@@ -356,8 +364,6 @@ def analyze_param(
                 field_info.default = Required
         elif isinstance(fastapi_annotation, params.Depends):
             depends = fastapi_annotation
-    elif annotation is not inspect.Signature.empty:
-        type_annotation = annotation
 
     if isinstance(value, params.Depends):
         assert depends is None, (
@@ -402,15 +408,15 @@ def analyze_param(
             # We might check here that `default_value is Required`, but the fact is that the same
             # parameter might sometimes be a path parameter and sometimes not. See
             # `tests/test_infer_param_optionality.py` for an example.
-            field_info = params.Path(annotation=type_annotation)
+            field_info = params.Path(annotation=use_annotation)
         elif is_uploadfile_or_nonable_uploadfile_annotation(
             type_annotation
         ) or is_uploadfile_sequence_annotation(type_annotation):
-            field_info = params.File(annotation=type_annotation, default=default_value)
+            field_info = params.File(annotation=use_annotation, default=default_value)
         elif not field_annotation_is_scalar(annotation=type_annotation):
-            field_info = params.Body(annotation=type_annotation, default=default_value)
+            field_info = params.Body(annotation=use_annotation, default=default_value)
         else:
-            field_info = params.Query(annotation=type_annotation, default=default_value)
+            field_info = params.Query(annotation=use_annotation, default=default_value)
 
     field = None
     if field_info is not None:
@@ -424,8 +430,8 @@ def analyze_param(
             and getattr(field_info, "in_", None) is None
         ):
             field_info.in_ = params.ParamTypes.query
-        use_annotation = get_annotation_from_field_info(
-            type_annotation,
+        use_annotation_from_field_info = get_annotation_from_field_info(
+            use_annotation,
             field_info,
             param_name,
         )
@@ -436,7 +442,7 @@ def analyze_param(
         field_info.alias = alias
         field = create_response_field(
             name=param_name,
-            type_=use_annotation,
+            type_=use_annotation_from_field_info,
             default=field_info.default,
             alias=alias,
             required=field_info.default in (Required, Undefined),
@@ -466,16 +472,17 @@ def is_body_param(*, param_field: ModelField, is_path_param: bool) -> bool:
 
 
 def add_param_to_fields(*, field: ModelField, dependant: Dependant) -> None:
-    field_info = cast(params.Param, field.field_info)
-    if field_info.in_ == params.ParamTypes.path:
+    field_info = field.field_info
+    field_info_in = getattr(field_info, "in_", None)
+    if field_info_in == params.ParamTypes.path:
         dependant.path_params.append(field)
-    elif field_info.in_ == params.ParamTypes.query:
+    elif field_info_in == params.ParamTypes.query:
         dependant.query_params.append(field)
-    elif field_info.in_ == params.ParamTypes.header:
+    elif field_info_in == params.ParamTypes.header:
         dependant.header_params.append(field)
     else:
         assert (
-            field_info.in_ == params.ParamTypes.cookie
+            field_info_in == params.ParamTypes.cookie
         ), f"non-body parameters must be in path, query, header or cookie: {field.name}"
         dependant.cookie_params.append(field)
 
index 42bcc27a1df8faa1af75662503ec7814a98bde46..8a31442eb0ca3fc9bd34ef6664c74fe35f04021e 100644 (file)
@@ -1,6 +1,8 @@
 import pytest
 from fastapi import Depends, FastAPI, Path
 from fastapi.param_functions import Query
+from fastapi.testclient import TestClient
+from fastapi.utils import PYDANTIC_V2
 from typing_extensions import Annotated
 
 app = FastAPI()
@@ -28,18 +30,13 @@ def test_no_annotated_defaults():
             pass  # pragma: nocover
 
 
-def test_no_multiple_annotations():
+def test_multiple_annotations():
     async def dep():
         pass  # pragma: nocover
 
-    with pytest.raises(
-        AssertionError,
-        match="Cannot specify multiple `Annotated` FastAPI arguments for 'foo'",
-    ):
-
-        @app.get("/")
-        async def get(foo: Annotated[int, Query(min_length=1), Query()]):
-            pass  # pragma: nocover
+    @app.get("/multi-query")
+    async def get(foo: Annotated[int, Query(gt=2), Query(lt=10)]):
+        return foo
 
     with pytest.raises(
         AssertionError,
@@ -64,3 +61,15 @@ def test_no_multiple_annotations():
         @app.get("/")
         async def get3(foo: Annotated[int, Query(min_length=1)] = Depends(dep)):
             pass  # pragma: nocover
+
+    client = TestClient(app)
+    response = client.get("/multi-query", params={"foo": "5"})
+    assert response.status_code == 200
+    assert response.json() == 5
+
+    response = client.get("/multi-query", params={"foo": "123"})
+    assert response.status_code == 422
+
+    if PYDANTIC_V2:
+        response = client.get("/multi-query", params={"foo": "1"})
+        assert response.status_code == 422
index 541f84bca1ca7c82cc9eca3af3f08498fad4522d..2222be9783c086d9cf1cfbcce37f16a7650d33c3 100644 (file)
@@ -57,7 +57,7 @@ foo_is_short = {
             {
                 "ctx": {"min_length": 1},
                 "loc": ["query", "foo"],
-                "msg": "String should have at least 1 characters",
+                "msg": "String should have at least 1 character",
                 "type": "string_too_short",
                 "input": "",
                 "url": match_pydantic_error_url("string_too_short"),