]> git.ipfire.org Git - thirdparty/fastapi/fastapi.git/commitdiff
🐛 Fix Enum handling with their own schema definitions (#1463)
authorSebastián Ramírez <tiangolo@gmail.com>
Sat, 23 May 2020 16:56:18 +0000 (18:56 +0200)
committerGitHub <noreply@github.com>
Sat, 23 May 2020 16:56:18 +0000 (18:56 +0200)
* 🐛 Fix extra support for enum with its own schema

* ✅ Fix/update test for enum with its own schema

* 🐛 Fix type declarations

* 🔧 Update format and lint scripts to support locally installed Pydantic and Starlette

* 🐛 Add temporary type ignores while enum schemas are merged

fastapi/dependencies/utils.py
fastapi/openapi/utils.py
fastapi/utils.py
scripts/format.sh
scripts/lint.sh
tests/test_tutorial/test_path_params/test_tutorial005.py

index 43ab4a0985bc3a7e5dc061e017aa2772dac05a7d..1a660f5d355faa200be273cd9be4286344a0be11 100644 (file)
@@ -188,6 +188,16 @@ def get_flat_dependant(
     return flat_dependant
 
 
+def get_flat_params(dependant: Dependant) -> List[ModelField]:
+    flat_dependant = get_flat_dependant(dependant, skip_repeats=True)
+    return (
+        flat_dependant.path_params
+        + flat_dependant.query_params
+        + flat_dependant.header_params
+        + flat_dependant.cookie_params
+    )
+
+
 def is_scalar_field(field: ModelField) -> bool:
     field_info = get_field_info(field)
     if not (
index c1e66fc8d29b1dda904716dec45a4540dde7b2aa..b5778327bbad182db019f684dd5ecb0d67d4fd57 100644 (file)
@@ -1,9 +1,10 @@
 import http.client
-from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, cast
+from enum import Enum
+from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Type, Union, cast
 
 from fastapi import routing
 from fastapi.dependencies.models import Dependant
-from fastapi.dependencies.utils import get_flat_dependant
+from fastapi.dependencies.utils import get_flat_dependant, get_flat_params
 from fastapi.encoders import jsonable_encoder
 from fastapi.openapi.constants import (
     METHODS_WITH_BODY,
@@ -15,11 +16,14 @@ from fastapi.params import Body, Param
 from fastapi.utils import (
     generate_operation_id_for_path,
     get_field_info,
-    get_flat_models_from_routes,
     get_model_definitions,
 )
 from pydantic import BaseModel
-from pydantic.schema import field_schema, get_model_name_map
+from pydantic.schema import (
+    field_schema,
+    get_flat_models_from_fields,
+    get_model_name_map,
+)
 from pydantic.utils import lenient_issubclass
 from starlette.responses import JSONResponse
 from starlette.routing import BaseRoute
@@ -64,16 +68,6 @@ status_code_ranges: Dict[str, str] = {
 }
 
 
-def get_openapi_params(dependant: Dependant) -> List[ModelField]:
-    flat_dependant = get_flat_dependant(dependant, skip_repeats=True)
-    return (
-        flat_dependant.path_params
-        + flat_dependant.query_params
-        + flat_dependant.header_params
-        + flat_dependant.cookie_params
-    )
-
-
 def get_openapi_security_definitions(flat_dependant: Dependant) -> Tuple[Dict, List]:
     security_definitions = {}
     operation_security = []
@@ -90,17 +84,22 @@ def get_openapi_security_definitions(flat_dependant: Dependant) -> Tuple[Dict, L
 
 
 def get_openapi_operation_parameters(
+    *,
     all_route_params: Sequence[ModelField],
+    model_name_map: Dict[Union[Type[BaseModel], Type[Enum]], str]
 ) -> List[Dict[str, Any]]:
     parameters = []
     for param in all_route_params:
         field_info = get_field_info(param)
         field_info = cast(Param, field_info)
+        # ignore mypy error until enum schemas are released
         parameter = {
             "name": param.alias,
             "in": field_info.in_.value,
             "required": param.required,
-            "schema": field_schema(param, model_name_map={})[0],
+            "schema": field_schema(
+                param, model_name_map=model_name_map, ref_prefix=REF_PREFIX  # type: ignore
+            )[0],
         }
         if field_info.description:
             parameter["description"] = field_info.description
@@ -111,13 +110,16 @@ def get_openapi_operation_parameters(
 
 
 def get_openapi_operation_request_body(
-    *, body_field: Optional[ModelField], model_name_map: Dict[Type[BaseModel], str]
+    *,
+    body_field: Optional[ModelField],
+    model_name_map: Dict[Union[Type[BaseModel], Type[Enum]], str]
 ) -> Optional[Dict]:
     if not body_field:
         return None
     assert isinstance(body_field, ModelField)
+    # ignore mypy error until enum schemas are released
     body_schema, _, _ = field_schema(
-        body_field, model_name_map=model_name_map, ref_prefix=REF_PREFIX
+        body_field, model_name_map=model_name_map, ref_prefix=REF_PREFIX  # type: ignore
     )
     field_info = cast(Body, get_field_info(body_field))
     request_media_type = field_info.media_type
@@ -176,8 +178,10 @@ def get_openapi_path(
                 operation.setdefault("security", []).extend(operation_security)
             if security_definitions:
                 security_schemes.update(security_definitions)
-            all_route_params = get_openapi_params(route.dependant)
-            operation_parameters = get_openapi_operation_parameters(all_route_params)
+            all_route_params = get_flat_params(route.dependant)
+            operation_parameters = get_openapi_operation_parameters(
+                all_route_params=all_route_params, model_name_map=model_name_map
+            )
             parameters.extend(operation_parameters)
             if parameters:
                 operation["parameters"] = list(
@@ -270,6 +274,38 @@ def get_openapi_path(
     return path, security_schemes, definitions
 
 
+def get_flat_models_from_routes(
+    routes: Sequence[BaseRoute],
+) -> Set[Union[Type[BaseModel], Type[Enum]]]:
+    body_fields_from_routes: List[ModelField] = []
+    responses_from_routes: List[ModelField] = []
+    request_fields_from_routes: List[ModelField] = []
+    callback_flat_models: Set[Union[Type[BaseModel], Type[Enum]]] = set()
+    for route in routes:
+        if getattr(route, "include_in_schema", None) and isinstance(
+            route, routing.APIRoute
+        ):
+            if route.body_field:
+                assert isinstance(
+                    route.body_field, ModelField
+                ), "A request body must be a Pydantic Field"
+                body_fields_from_routes.append(route.body_field)
+            if route.response_field:
+                responses_from_routes.append(route.response_field)
+            if route.response_fields:
+                responses_from_routes.extend(route.response_fields.values())
+            if route.callbacks:
+                callback_flat_models |= get_flat_models_from_routes(route.callbacks)
+            params = get_flat_params(route.dependant)
+            request_fields_from_routes.extend(params)
+
+    flat_models = callback_flat_models | get_flat_models_from_fields(
+        body_fields_from_routes + responses_from_routes + request_fields_from_routes,
+        known_models=set(),
+    )
+    return flat_models
+
+
 def get_openapi(
     *,
     title: str,
@@ -286,9 +322,11 @@ def get_openapi(
     components: Dict[str, Dict] = {}
     paths: Dict[str, Dict] = {}
     flat_models = get_flat_models_from_routes(routes)
-    model_name_map = get_model_name_map(flat_models)
+    # ignore mypy error until enum schemas are released
+    model_name_map = get_model_name_map(flat_models)  # type: ignore
+    # ignore mypy error until enum schemas are released
     definitions = get_model_definitions(
-        flat_models=flat_models, model_name_map=model_name_map
+        flat_models=flat_models, model_name_map=model_name_map  # type: ignore
     )
     for route in routes:
         if isinstance(route, routing.APIRoute):
index 154dd9aa180e3e9e0a1dbf247fb10d1a8d559581..c9022fbc3b43b3cb6565c7d4b4e57219d85da823 100644 (file)
@@ -1,17 +1,16 @@
 import functools
 import re
 from dataclasses import is_dataclass
-from typing import Any, Dict, List, Optional, Sequence, Set, Type, Union, cast
+from enum import Enum
+from typing import Any, Dict, Optional, Set, Type, Union, cast
 
 import fastapi
-from fastapi import routing
 from fastapi.logger import logger
 from fastapi.openapi.constants import REF_PREFIX
 from pydantic import BaseConfig, BaseModel, create_model
 from pydantic.class_validators import Validator
-from pydantic.schema import get_flat_models_from_fields, model_process_schema
+from pydantic.schema import model_process_schema
 from pydantic.utils import lenient_issubclass
-from starlette.routing import BaseRoute
 
 try:
     from pydantic.fields import FieldInfo, ModelField, UndefinedType
@@ -50,38 +49,16 @@ def warning_response_model_skip_defaults_deprecated() -> None:
     )
 
 
-def get_flat_models_from_routes(routes: Sequence[BaseRoute]) -> Set[Type[BaseModel]]:
-    body_fields_from_routes: List[ModelField] = []
-    responses_from_routes: List[ModelField] = []
-    callback_flat_models: Set[Type[BaseModel]] = set()
-    for route in routes:
-        if getattr(route, "include_in_schema", None) and isinstance(
-            route, routing.APIRoute
-        ):
-            if route.body_field:
-                assert isinstance(
-                    route.body_field, ModelField
-                ), "A request body must be a Pydantic Field"
-                body_fields_from_routes.append(route.body_field)
-            if route.response_field:
-                responses_from_routes.append(route.response_field)
-            if route.response_fields:
-                responses_from_routes.extend(route.response_fields.values())
-            if route.callbacks:
-                callback_flat_models |= get_flat_models_from_routes(route.callbacks)
-    flat_models = callback_flat_models | get_flat_models_from_fields(
-        body_fields_from_routes + responses_from_routes, known_models=set()
-    )
-    return flat_models
-
-
 def get_model_definitions(
-    *, flat_models: Set[Type[BaseModel]], model_name_map: Dict[Type[BaseModel], str]
+    *,
+    flat_models: Set[Union[Type[BaseModel], Type[Enum]]],
+    model_name_map: Dict[Union[Type[BaseModel], Type[Enum]], str],
 ) -> Dict[str, Any]:
     definitions: Dict[str, Dict] = {}
     for model in flat_models:
+        # ignore mypy error until enum schemas are released
         m_schema, m_definitions, m_nested_models = model_process_schema(
-            model, model_name_map=model_name_map, ref_prefix=REF_PREFIX
+            model, model_name_map=model_name_map, ref_prefix=REF_PREFIX  # type: ignore
         )
         definitions.update(m_definitions)
         model_name = model_name_map[model]
index bbcb04354b01afdf985ddf457f99ef9f0746b648..07ce78f699c0ad8b600a69410452c19636fe97f5 100755 (executable)
@@ -3,4 +3,4 @@ set -x
 
 autoflake --remove-all-unused-imports --recursive --remove-unused-variables --in-place docs_src fastapi tests scripts --exclude=__init__.py
 black fastapi tests docs_src scripts
-isort --multi-line=3 --trailing-comma --force-grid-wrap=0 --combine-as --line-width 88 --recursive --thirdparty fastapi --apply fastapi tests docs_src scripts
+isort --multi-line=3 --trailing-comma --force-grid-wrap=0 --combine-as --line-width 88 --recursive --thirdparty fastapi --thirdparty pydantic --thirdparty starlette --apply fastapi tests docs_src scripts
index 6472f18454adc3f86d8dc6bb7baccc7e2b78d69c..ec0f7f41bfcf54d02f3ff17e6e38ddab0feac9cd 100755 (executable)
@@ -5,4 +5,4 @@ set -x
 
 mypy fastapi
 black fastapi tests --check
-isort --multi-line=3 --trailing-comma --force-grid-wrap=0 --combine-as --line-width 88 --recursive --check-only --thirdparty fastapi fastapi tests
+isort --multi-line=3 --trailing-comma --force-grid-wrap=0 --combine-as --line-width 88 --recursive --check-only --thirdparty fastapi --thirdparty fastapi --thirdparty pydantic --thirdparty starlette fastapi tests
index b0e0535e8a627099de5f396fd180570350756168..836a6264b22bc638d4e66418385da4931e5fc44e 100644 (file)
@@ -87,7 +87,7 @@ openapi_schema2 = {
                 "parameters": [
                     {
                         "required": True,
-                        "schema": {"$ref": "#/definitions/ModelName"},
+                        "schema": {"$ref": "#/components/schemas/ModelName"},
                         "name": "model_name",
                         "in": "path",
                     }
@@ -124,6 +124,12 @@ openapi_schema2 = {
                     }
                 },
             },
+            "ModelName": {
+                "title": "ModelName",
+                "enum": ["alexnet", "resnet", "lenet"],
+                "type": "string",
+                "description": "An enumeration.",
+            },
             "ValidationError": {
                 "title": "ValidationError",
                 "required": ["loc", "msg", "type"],