]> git.ipfire.org Git - thirdparty/fastapi/fastapi.git/commitdiff
:recycle: Refactor, fix and update code
authorSebastián Ramírez <tiangolo@gmail.com>
Thu, 6 Dec 2018 16:24:50 +0000 (20:24 +0400)
committerSebastián Ramírez <tiangolo@gmail.com>
Thu, 6 Dec 2018 16:24:50 +0000 (20:24 +0400)
18 files changed:
fastapi/__init__.py
fastapi/applications.py
fastapi/dependencies/__init__.py [new file with mode: 0644]
fastapi/dependencies/models.py [new file with mode: 0644]
fastapi/dependencies/utils.py [new file with mode: 0644]
fastapi/encoders.py [moved from fastapi/pydantic_utils.py with 56% similarity]
fastapi/openapi/__init__.py [new file with mode: 0644]
fastapi/openapi/constants.py [new file with mode: 0644]
fastapi/openapi/models.py [new file with mode: 0644]
fastapi/openapi/utils.py [new file with mode: 0644]
fastapi/params.py
fastapi/routing.py
fastapi/security/api_key.py
fastapi/security/base.py
fastapi/security/http.py
fastapi/security/oauth2.py
fastapi/security/open_id_connect_url.py
fastapi/utils.py [new file with mode: 0644]

index a52bbccf6676c4ffed58ea471da2618aed2d9248..2bb1b27c24ad91424eb86f2ff75f8b1b3c90aa99 100644 (file)
@@ -1,3 +1,3 @@
 """Fast API framework, fast high performance, fast to learn, fast to code"""
 
-__version__ = '0.1'
+__version__ = "0.1"
index 2e1875aa1c380f31fe664edd67997f66ff548a58..3f5a45b73d6d20ee1b77c4fba1a6eda0bec24ddf 100644 (file)
@@ -1,61 +1,19 @@
-import typing
-import inspect
+from typing import Any, Callable, Dict, List, Type
 
 from starlette.applications import Starlette
-from starlette.middleware.lifespan import LifespanMiddleware
+from starlette.exceptions import ExceptionMiddleware, HTTPException
 from starlette.middleware.errors import ServerErrorMiddleware
-from starlette.exceptions import ExceptionMiddleware
-from starlette.responses import JSONResponse, HTMLResponse, PlainTextResponse
-from starlette.requests import Request
-from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY
+from starlette.middleware.lifespan import LifespanMiddleware
+from starlette.responses import JSONResponse
 
-from pydantic import BaseModel, BaseConfig, Schema
-from pydantic.utils import lenient_issubclass
-from pydantic.fields import Field
-from pydantic.schema import (
-    field_schema,
-    get_flat_models_from_models,
-    get_flat_models_from_fields,
-    get_model_name_map,
-    schema,
-    model_process_schema,
-)
 
-from .routing import APIRouter, APIRoute, get_openapi_params, get_flat_dependant
-from .pydantic_utils import jsonable_encoder
+from fastapi import routing
+from fastapi.openapi.utils import get_swagger_ui_html, get_openapi, get_redoc_html
 
 
-def docs(openapi_url):
-    return HTMLResponse(
-        """
-    <! doctype html>
-    <html>
-    <head>
-    <link type="text/css" rel="stylesheet" href="//unpkg.com/swagger-ui-dist@3/swagger-ui.css">
-    </head>
-    <body>
-    <div id="swagger-ui">
-    </div>
-    <script src="//unpkg.com/swagger-ui-dist@3/swagger-ui-bundle.js"></script>
-    <!-- `SwaggerUIBundle` is now available on the page -->
-    <script>
-            
-    const ui = SwaggerUIBundle({
-        url: '""" + openapi_url + """',
-        dom_id: '#swagger-ui',
-        presets: [
-        SwaggerUIBundle.presets.apis,
-        SwaggerUIBundle.SwaggerUIStandalonePreset
-        ],
-        layout: "BaseLayout"
-    })
-    </script>
-    </body>
-    </html>
-    """,
-        media_type="text/html",
-    )
+async def http_exception(request, exc: HTTPException):
+    print(exc)
+    return JSONResponse({"detail": exc.detail}, status_code=exc.status_code)
 
 
 class FastAPI(Starlette):
@@ -67,24 +25,26 @@ class FastAPI(Starlette):
         description: str = "",
         version: str = "0.1.0",
         openapi_url: str = "/openapi.json",
-        docs_url: str = "/docs",
-        **extra: typing.Dict[str, typing.Any],
+        swagger_ui_url: str = "/docs",
+        redoc_url: str = "/redoc",
+        **extra: Dict[str, Any],
     ) -> None:
         self._debug = debug
-        self.router = APIRouter()
+        self.router = routing.APIRouter()
         self.exception_middleware = ExceptionMiddleware(self.router, debug=debug)
         self.error_middleware = ServerErrorMiddleware(
             self.exception_middleware, debug=debug
         )
         self.lifespan_middleware = LifespanMiddleware(self.error_middleware)
-        self.schema_generator = None  # type: typing.Optional[BaseSchemaGenerator]
+        self.schema_generator = None
         self.template_env = self.load_template_env(template_directory)
 
         self.title = title
         self.description = description
         self.version = version
         self.openapi_url = openapi_url
-        self.docs_url = docs_url
+        self.swagger_ui_url = swagger_ui_url
+        self.redoc_url = redoc_url
         self.extra = extra
 
         self.openapi_version = "3.0.2"
@@ -93,29 +53,52 @@ class FastAPI(Starlette):
             assert self.title, "A title must be provided for OpenAPI, e.g.: 'My API'"
             assert self.version, "A version must be provided for OpenAPI, e.g.: '2.1.0'"
 
-        if self.docs_url:
+        if self.swagger_ui_url or self.redoc_url:
             assert self.openapi_url, "The openapi_url is required for the docs"
+        self.setup()
 
-        self.add_route(
-            self.openapi_url,
-            lambda req: JSONResponse(self.openapi()),
-            include_in_schema=False,
-        )
-        self.add_route(self.docs_url, lambda r: docs(self.openapi_url), include_in_schema=False)
+    def setup(self):
+        if self.openapi_url:
+            self.add_route(
+                self.openapi_url,
+                lambda req: JSONResponse(
+                    get_openapi(
+                        title=self.title,
+                        version=self.version,
+                        openapi_version=self.openapi_version,
+                        description=self.description,
+                        routes=self.routes,
+                    )
+                ),
+                include_in_schema=False,
+            )
+        if self.swagger_ui_url:
+            self.add_route(
+                self.swagger_ui_url,
+                lambda r: get_swagger_ui_html(openapi_url=self.openapi_url, title=self.title + " - Swagger UI"),
+                include_in_schema=False,
+            )
+        if self.redoc_url:
+            self.add_route(
+                self.redoc_url,
+                lambda r: get_redoc_html(openapi_url=self.openapi_url, title=self.title + " - ReDoc"),
+                include_in_schema=False,
+            )
+        self.add_exception_handler(HTTPException, http_exception)
 
     def add_api_route(
         self,
         path: str,
-        endpoint: typing.Callable,
-        methods: typing.List[str] = None,
+        endpoint: Callable,
+        methods: List[str] = None,
         name: str = None,
         include_in_schema: bool = True,
-        tags: typing.List[str] = [],
+        tags: List[str] = None,
         summary: str = None,
         description: str = None,
         operation_id: str = None,
         deprecated: bool = None,
-        response_type: typing.Type = None,
+        response_type: Type = None,
         response_description: str = "Successful Response",
         response_code=200,
         response_wrapper=JSONResponse,
@@ -126,7 +109,7 @@ class FastAPI(Starlette):
             methods=methods,
             name=name,
             include_in_schema=include_in_schema,
-            tags=tags,
+            tags=tags or [],
             summary=summary,
             description=description,
             operation_id=operation_id,
@@ -140,27 +123,27 @@ class FastAPI(Starlette):
     def api_route(
         self,
         path: str,
-        methods: typing.List[str] = None,
+        methods: List[str] = None,
         name: str = None,
         include_in_schema: bool = True,
-        tags: typing.List[str] = [],
+        tags: List[str] = None,
         summary: str = None,
         description: str = None,
         operation_id: str = None,
         deprecated: bool = None,
-        response_type: typing.Type = None,
+        response_type: Type = None,
         response_description: str = "Successful Response",
         response_code=200,
         response_wrapper=JSONResponse,
-    ) -> typing.Callable:
-        def decorator(func: typing.Callable) -> typing.Callable:
+    ) -> Callable:
+        def decorator(func: Callable) -> Callable:
             self.router.add_api_route(
                 path,
                 func,
                 methods=methods,
                 name=name,
                 include_in_schema=include_in_schema,
-                tags=tags,
+                tags=tags or [],
                 summary=summary,
                 description=description,
                 operation_id=operation_id,
@@ -179,12 +162,12 @@ class FastAPI(Starlette):
         path: str,
         name: str = None,
         include_in_schema: bool = True,
-        tags: typing.List[str] = [],
+        tags: List[str] = None,
         summary: str = None,
         description: str = None,
         operation_id: str = None,
         deprecated: bool = None,
-        response_type: typing.Type = None,
+        response_type: Type = None,
         response_description: str = "Successful Response",
         response_code=200,
         response_wrapper=JSONResponse,
@@ -193,7 +176,7 @@ class FastAPI(Starlette):
             path=path,
             name=name,
             include_in_schema=include_in_schema,
-            tags=tags,
+            tags=tags or [],
             summary=summary,
             description=description,
             operation_id=operation_id,
@@ -209,12 +192,12 @@ class FastAPI(Starlette):
         path: str,
         name: str = None,
         include_in_schema: bool = True,
-        tags: typing.List[str] = [],
+        tags: List[str] = None,
         summary: str = None,
         description: str = None,
         operation_id: str = None,
         deprecated: bool = None,
-        response_type: typing.Type = None,
+        response_type: Type = None,
         response_description: str = "Successful Response",
         response_code=200,
         response_wrapper=JSONResponse,
@@ -223,7 +206,7 @@ class FastAPI(Starlette):
             path=path,
             name=name,
             include_in_schema=include_in_schema,
-            tags=tags,
+            tags=tags or [],
             summary=summary,
             description=description,
             operation_id=operation_id,
@@ -239,12 +222,12 @@ class FastAPI(Starlette):
         path: str,
         name: str = None,
         include_in_schema: bool = True,
-        tags: typing.List[str] = [],
+        tags: List[str] = None,
         summary: str = None,
         description: str = None,
         operation_id: str = None,
         deprecated: bool = None,
-        response_type: typing.Type = None,
+        response_type: Type = None,
         response_description: str = "Successful Response",
         response_code=200,
         response_wrapper=JSONResponse,
@@ -253,7 +236,7 @@ class FastAPI(Starlette):
             path=path,
             name=name,
             include_in_schema=include_in_schema,
-            tags=tags,
+            tags=tags or [],
             summary=summary,
             description=description,
             operation_id=operation_id,
@@ -269,12 +252,12 @@ class FastAPI(Starlette):
         path: str,
         name: str = None,
         include_in_schema: bool = True,
-        tags: typing.List[str] = [],
+        tags: List[str] = None,
         summary: str = None,
         description: str = None,
         operation_id: str = None,
         deprecated: bool = None,
-        response_type: typing.Type = None,
+        response_type: Type = None,
         response_description: str = "Successful Response",
         response_code=200,
         response_wrapper=JSONResponse,
@@ -283,7 +266,7 @@ class FastAPI(Starlette):
             path=path,
             name=name,
             include_in_schema=include_in_schema,
-            tags=tags,
+            tags=tags or [],
             summary=summary,
             description=description,
             operation_id=operation_id,
@@ -299,12 +282,12 @@ class FastAPI(Starlette):
         path: str,
         name: str = None,
         include_in_schema: bool = True,
-        tags: typing.List[str] = [],
+        tags: List[str] = None,
         summary: str = None,
         description: str = None,
         operation_id: str = None,
         deprecated: bool = None,
-        response_type: typing.Type = None,
+        response_type: Type = None,
         response_description: str = "Successful Response",
         response_code=200,
         response_wrapper=JSONResponse,
@@ -313,7 +296,7 @@ class FastAPI(Starlette):
             path=path,
             name=name,
             include_in_schema=include_in_schema,
-            tags=tags,
+            tags=tags or [],
             summary=summary,
             description=description,
             operation_id=operation_id,
@@ -329,12 +312,12 @@ class FastAPI(Starlette):
         path: str,
         name: str = None,
         include_in_schema: bool = True,
-        tags: typing.List[str] = [],
+        tags: List[str] = None,
         summary: str = None,
         description: str = None,
         operation_id: str = None,
         deprecated: bool = None,
-        response_type: typing.Type = None,
+        response_type: Type = None,
         response_description: str = "Successful Response",
         response_code=200,
         response_wrapper=JSONResponse,
@@ -343,7 +326,7 @@ class FastAPI(Starlette):
             path=path,
             name=name,
             include_in_schema=include_in_schema,
-            tags=tags,
+            tags=tags or [],
             summary=summary,
             description=description,
             operation_id=operation_id,
@@ -359,12 +342,12 @@ class FastAPI(Starlette):
         path: str,
         name: str = None,
         include_in_schema: bool = True,
-        tags: typing.List[str] = [],
+        tags: List[str] = None,
         summary: str = None,
         description: str = None,
         operation_id: str = None,
         deprecated: bool = None,
-        response_type: typing.Type = None,
+        response_type: Type = None,
         response_description: str = "Successful Response",
         response_code=200,
         response_wrapper=JSONResponse,
@@ -373,7 +356,7 @@ class FastAPI(Starlette):
             path=path,
             name=name,
             include_in_schema=include_in_schema,
-            tags=tags,
+            tags=tags or [],
             summary=summary,
             description=description,
             operation_id=operation_id,
@@ -389,12 +372,12 @@ class FastAPI(Starlette):
         path: str,
         name: str = None,
         include_in_schema: bool = True,
-        tags: typing.List[str] = [],
+        tags: List[str] = None,
         summary: str = None,
         description: str = None,
         operation_id: str = None,
         deprecated: bool = None,
-        response_type: typing.Type = None,
+        response_type: Type = None,
         response_description: str = "Successful Response",
         response_code=200,
         response_wrapper=JSONResponse,
@@ -403,7 +386,7 @@ class FastAPI(Starlette):
             path=path,
             name=name,
             include_in_schema=include_in_schema,
-            tags=tags,
+            tags=tags or [],
             summary=summary,
             description=description,
             operation_id=operation_id,
@@ -413,169 +396,3 @@ class FastAPI(Starlette):
             response_code=response_code,
             response_wrapper=response_wrapper,
         )
-
-    def openapi(self):
-        info = {"title": self.title, "version": self.version}
-        if self.description:
-            info["description"] = self.description
-        output = {"openapi": self.openapi_version, "info": info}
-        components = {}
-        paths = {}
-        methods_with_body = set(("POST", "PUT"))
-        body_fields_from_routes = []
-        responses_from_routes = []
-        ref_prefix = "#/components/schemas/"
-        for route in self.routes:
-            route: APIRoute
-            if route.include_in_schema and isinstance(route, APIRoute):
-                if route.request_body:
-                    assert isinstance(
-                        route.request_body, Field
-                    ), "A request body must be a Pydantic BaseModel or Field"
-                    body_fields_from_routes.append(route.request_body)
-                if route.response_field:
-                    responses_from_routes.append(route.response_field)
-        flat_models = get_flat_models_from_fields(
-            body_fields_from_routes + responses_from_routes
-        )
-        model_name_map = get_model_name_map(flat_models)
-        definitions = {}
-        for model in flat_models:
-            m_schema, m_definitions = model_process_schema(
-                model, model_name_map=model_name_map, ref_prefix=ref_prefix
-            )
-            definitions.update(m_definitions)
-            model_name = model_name_map[model]
-            definitions[model_name] = m_schema
-        validation_error_definition = {
-            "title": "ValidationError",
-            "type": "object",
-            "properties": {
-                "loc": {
-                    "title": "Location",
-                    "type": "array",
-                    "items": {"type": "string"},
-                },
-                "msg": {"title": "Message", "type": "string"},
-                "type": {"title": "Error Type", "type": "string"},
-            },
-            "required": ["loc", "msg", "type"],
-        }
-        validation_error_response_definition = {
-            "title": "HTTPValidationError",
-            "type": "object",
-            "properties": {
-                "detail": {
-                    "title": "Detail",
-                    "type": "array",
-                    "items": {"$ref": ref_prefix + "ValidationError"},
-                }
-            },
-        }
-        for route in self.routes:
-            route: APIRoute
-            if route.include_in_schema and isinstance(route, APIRoute):
-                path = paths.get(route.path, {})
-                for method in route.methods:
-                    operation = {}
-                    if route.tags:
-                        operation["tags"] = route.tags
-                    if route.summary:
-                        operation["summary"] = route.summary
-                    if route.description:
-                        operation["description"] = route.description
-                    if route.operation_id:
-                        operation["operationId"] = route.operation_id
-                    else:
-                        operation["operationId"] = route.name
-                    if route.deprecated:
-                        operation["deprecated"] = route.deprecated
-                    parameters = []
-                    flat_dependant = get_flat_dependant(route.dependant)
-                    security_definitions = {}
-                    for security_scheme in flat_dependant.security_schemes:
-                        security_definition = jsonable_encoder(security_scheme, exclude=("scheme_name",), by_alias=True, include_none=False)
-                        security_name = getattr(security_scheme, "scheme_name", None) or security_scheme.__class__.__name__
-                        security_definitions[security_name] = security_definition
-                    if security_definitions:
-                        components.setdefault("securitySchemes", {}).update(security_definitions)
-                        operation["security"] = [{name: []} for name in security_definitions]
-                    all_route_params = get_openapi_params(route.dependant)
-                    for param in all_route_params:
-                        if "ValidationError" not in definitions:
-                            definitions["ValidationError"] = validation_error_definition
-                            definitions[
-                                "HTTPValidationError"
-                            ] = validation_error_response_definition
-                        parameter = {
-                            "name": param.alias,
-                            "in": param.schema.in_.value,
-                            "required": param.required,
-                            "schema": field_schema(param, model_name_map={})[0],
-                        }
-                        if param.schema.description:
-                            parameter["description"] = param.schema.description
-                        if param.schema.deprecated:
-                            parameter["deprecated"] = param.schema.deprecated
-                        parameters.append(parameter)
-                    if parameters:
-                        operation["parameters"] = parameters
-                    if method in methods_with_body:
-                        request_body = getattr(route, "request_body", None)
-                        if request_body:
-                            assert isinstance(request_body, Field)
-                            body_schema, _ = field_schema(
-                                request_body,
-                                model_name_map=model_name_map,
-                                ref_prefix=ref_prefix,
-                            )
-                            required = request_body.required
-                            request_body_oai = {}
-                            if required:
-                                request_body_oai["required"] = required
-                            request_body_oai["content"] = {
-                                "application/json": {"schema": body_schema}
-                            }
-                            operation["requestBody"] = request_body_oai
-                    response_code = str(route.response_code)
-                    response_schema = {"type": "string"}
-                    if lenient_issubclass(route.response_wrapper, JSONResponse):
-                        response_media_type = "application/json"
-                        if route.response_field:
-                            response_schema, _ = field_schema(
-                                route.response_field,
-                                model_name_map=model_name_map,
-                                ref_prefix=ref_prefix,
-                            )
-                        else:
-                            response_schema = {}
-                    elif lenient_issubclass(route.response_wrapper, HTMLResponse):
-                        response_media_type = "text/html"
-                    else:
-                        response_media_type = "text/plain"
-                    content = {response_media_type: {"schema": response_schema}}
-                    operation["responses"] = {
-                        response_code: {
-                            "description": route.response_description,
-                            "content": content,
-                        }
-                    }
-                    if all_route_params or getattr(route, "request_body", None):
-                        operation["responses"][str(HTTP_422_UNPROCESSABLE_ENTITY)] = {
-                            "description": "Validation Error",
-                            "content": {
-                                "application/json": {
-                                    "schema": {
-                                        "$ref": ref_prefix + "HTTPValidationError"
-                                    }
-                                }
-                            },
-                        }
-                    path[method.lower()] = operation
-                paths[route.path] = path
-        if definitions:
-            components.setdefault("schemas", {}).update(definitions)
-        if components:
-            output["components"] = components
-        output["paths"] = paths
-        return output
diff --git a/fastapi/dependencies/__init__.py b/fastapi/dependencies/__init__.py
new file mode 100644 (file)
index 0000000..e69de29
diff --git a/fastapi/dependencies/models.py b/fastapi/dependencies/models.py
new file mode 100644 (file)
index 0000000..ad9419d
--- /dev/null
@@ -0,0 +1,46 @@
+from typing import Any, Callable, Dict, List, Sequence, Tuple
+
+from starlette.concurrency import run_in_threadpool
+from starlette.requests import Request
+
+from fastapi.security.base import SecurityBase
+from pydantic import BaseConfig, Schema
+from pydantic.error_wrappers import ErrorWrapper
+from pydantic.errors import MissingError
+from pydantic.fields import Field, Required
+from pydantic.schema import get_annotation_from_schema
+
+param_supported_types = (str, int, float, bool)
+
+
+class SecurityRequirement:
+    def __init__(self, security_scheme: SecurityBase, scopes: Sequence[str] = None):
+        self.security_scheme = security_scheme
+        self.scopes = scopes
+
+
+class Dependant:
+    def __init__(
+        self,
+        *,
+        path_params: List[Field] = None,
+        query_params: List[Field] = None,
+        header_params: List[Field] = None,
+        cookie_params: List[Field] = None,
+        body_params: List[Field] = None,
+        dependencies: List["Dependant"] = None,
+        security_schemes: List[SecurityRequirement] = None,
+        name: str = None,
+        call: Callable = None,
+        request_param_name: str = None,
+    ) -> None:
+        self.path_params = path_params or []
+        self.query_params = query_params or []
+        self.header_params = header_params or []
+        self.cookie_params = cookie_params or []
+        self.body_params = body_params or []
+        self.dependencies = dependencies or []
+        self.security_requirements = security_schemes or []
+        self.request_param_name = request_param_name
+        self.name = name
+        self.call = call
diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py
new file mode 100644 (file)
index 0000000..6e86de5
--- /dev/null
@@ -0,0 +1,327 @@
+import asyncio
+import inspect
+from copy import deepcopy
+from typing import Any, Callable, Dict, List, Tuple
+
+from starlette.concurrency import run_in_threadpool
+from starlette.requests import Request
+
+from fastapi import params
+from fastapi.dependencies.models import Dependant, SecurityRequirement
+from fastapi.security.base import SecurityBase
+from fastapi.utils import get_path_param_names
+from pydantic import BaseConfig, Schema, create_model
+from pydantic.error_wrappers import ErrorWrapper
+from pydantic.errors import MissingError
+from pydantic.fields import Field, Required
+from pydantic.schema import get_annotation_from_schema
+from pydantic.utils import lenient_issubclass
+
+param_supported_types = (str, int, float, bool)
+
+
+def get_sub_dependant(*, param: inspect.Parameter, path: str):
+    depends: params.Depends = param.default
+    if depends.dependency:
+        dependency = depends.dependency
+    else:
+        dependency = param.annotation
+    assert callable(dependency)
+    sub_dependant = get_dependant(path=path, call=dependency, name=param.name)
+    if isinstance(depends, params.Security) and isinstance(dependency, SecurityBase):
+        security_requirement = SecurityRequirement(
+            security_scheme=dependency, scopes=depends.scopes
+        )
+        sub_dependant.security_requirements.append(security_requirement)
+    return sub_dependant
+
+
+def get_flat_dependant(dependant: Dependant):
+    flat_dependant = Dependant(
+        path_params=dependant.path_params.copy(),
+        query_params=dependant.query_params.copy(),
+        header_params=dependant.header_params.copy(),
+        cookie_params=dependant.cookie_params.copy(),
+        body_params=dependant.body_params.copy(),
+        security_schemes=dependant.security_requirements.copy(),
+    )
+    for sub_dependant in dependant.dependencies:
+        if sub_dependant is dependant:
+            raise ValueError("recursion", dependant.dependencies)
+        flat_sub = get_flat_dependant(sub_dependant)
+        flat_dependant.path_params.extend(flat_sub.path_params)
+        flat_dependant.query_params.extend(flat_sub.query_params)
+        flat_dependant.header_params.extend(flat_sub.header_params)
+        flat_dependant.cookie_params.extend(flat_sub.cookie_params)
+        flat_dependant.body_params.extend(flat_sub.body_params)
+        flat_dependant.security_requirements.extend(flat_sub.security_requirements)
+    return flat_dependant
+
+
+def get_dependant(*, path: str, call: Callable, name: str = None):
+    path_param_names = get_path_param_names(path)
+    endpoint_signature = inspect.signature(call)
+    signature_params = endpoint_signature.parameters
+    dependant = Dependant(call=call, name=name)
+    for param_name in signature_params:
+        param = signature_params[param_name]
+        if isinstance(param.default, params.Depends):
+            sub_dependant = get_sub_dependant(param=param, path=path)
+            dependant.dependencies.append(sub_dependant)
+    for param_name in signature_params:
+        param = signature_params[param_name]
+        if (
+            (param.default == param.empty) or isinstance(param.default, params.Path)
+        ) and (param_name in path_param_names):
+            assert lenient_issubclass(
+                param.annotation, param_supported_types
+            ) or param.annotation == param.empty, f"Path params must be of type str, int, float or boot: {param}"
+            param = signature_params[param_name]
+            add_param_to_fields(
+                param=param,
+                dependant=dependant,
+                default_schema=params.Path,
+                force_type=params.ParamTypes.path,
+            )
+        elif (param.default == param.empty or param.default is None) and (
+            param.annotation == param.empty
+            or lenient_issubclass(param.annotation, param_supported_types)
+        ):
+            add_param_to_fields(
+                param=param, dependant=dependant, default_schema=params.Query
+            )
+        elif isinstance(param.default, params.Param):
+            if param.annotation != param.empty:
+                assert lenient_issubclass(
+                    param.annotation, param_supported_types
+                ), f"Parameters for Path, Query, Header and Cookies must be of type str, int, float or bool: {param}"
+            add_param_to_fields(
+                param=param, dependant=dependant, default_schema=params.Query
+            )
+        elif lenient_issubclass(param.annotation, Request):
+            dependant.request_param_name = param_name
+        elif not isinstance(param.default, params.Depends):
+            add_param_to_body_fields(param=param, dependant=dependant)
+    return dependant
+
+
+def add_param_to_fields(
+    *,
+    param: inspect.Parameter,
+    dependant: Dependant,
+    default_schema=params.Param,
+    force_type: params.ParamTypes = None,
+):
+    default_value = Required
+    if not param.default == param.empty:
+        default_value = param.default
+    if isinstance(default_value, params.Param):
+        schema = default_value
+        default_value = schema.default
+        if schema.in_ is None:
+            schema.in_ = default_schema.in_
+        if force_type:
+            schema.in_ = force_type
+    else:
+        schema = default_schema(default_value)
+    required = default_value == Required
+    annotation = Any
+    if not param.annotation == param.empty:
+        annotation = param.annotation
+    annotation = get_annotation_from_schema(annotation, schema)
+    field = Field(
+        name=param.name,
+        type_=annotation,
+        default=None if required else default_value,
+        alias=schema.alias or param.name,
+        required=required,
+        model_config=BaseConfig(),
+        class_validators=[],
+        schema=schema,
+    )
+    if schema.in_ == params.ParamTypes.path:
+        dependant.path_params.append(field)
+    elif schema.in_ == params.ParamTypes.query:
+        dependant.query_params.append(field)
+    elif schema.in_ == params.ParamTypes.header:
+        dependant.header_params.append(field)
+    else:
+        assert (
+            schema.in_ == params.ParamTypes.cookie
+        ), f"non-body parameters must be in path, query, header or cookie: {param.name}"
+        dependant.cookie_params.append(field)
+
+
+def add_param_to_body_fields(*, param: inspect.Parameter, dependant: Dependant):
+    default_value = Required
+    if not param.default == param.empty:
+        default_value = param.default
+    if isinstance(default_value, Schema):
+        schema = default_value
+        default_value = schema.default
+    else:
+        schema = Schema(default_value)
+    required = default_value == Required
+    annotation = get_annotation_from_schema(param.annotation, schema)
+    field = Field(
+        name=param.name,
+        type_=annotation,
+        default=None if required else default_value,
+        alias=schema.alias or param.name,
+        required=required,
+        model_config=BaseConfig,
+        class_validators=[],
+        schema=schema,
+    )
+    dependant.body_params.append(field)
+
+
+def is_coroutine_callable(call: Callable = None):
+    if not call:
+        return False
+    if inspect.isfunction(call):
+        return asyncio.iscoroutinefunction(call)
+    if inspect.isclass(call):
+        return False
+    call = getattr(call, "__call__", None)
+    if not call:
+        return False
+    return asyncio.iscoroutinefunction(call)
+
+
+async def solve_dependencies(
+    *, request: Request, dependant: Dependant, body: Dict[str, Any] = None
+):
+    values: Dict[str, Any] = {}
+    errors: List[ErrorWrapper] = []
+    for sub_dependant in dependant.dependencies:
+        sub_values, sub_errors = await solve_dependencies(
+            request=request, dependant=sub_dependant, body=body
+        )
+        if sub_errors:
+            return {}, errors
+        if sub_dependant.call and is_coroutine_callable(sub_dependant.call):
+            solved = await sub_dependant.call(**sub_values)
+        else:
+            solved = await run_in_threadpool(sub_dependant.call, **sub_values)
+        values[
+            sub_dependant.name
+        ] = solved  # type: ignore # Sub-dependants always have a name
+    path_values, path_errors = request_params_to_args(
+        dependant.path_params, request.path_params
+    )
+    query_values, query_errors = request_params_to_args(
+        dependant.query_params, request.query_params
+    )
+    header_values, header_errors = request_params_to_args(
+        dependant.header_params, request.headers
+    )
+    cookie_values, cookie_errors = request_params_to_args(
+        dependant.cookie_params, request.cookies
+    )
+    values.update(path_values)
+    values.update(query_values)
+    values.update(header_values)
+    values.update(cookie_values)
+    errors = path_errors + query_errors + header_errors + cookie_errors
+    if dependant.body_params:
+        body_values, body_errors = await request_body_to_args(  # type: ignore # body_params checked above
+            dependant.body_params, body
+        )
+        values.update(body_values)
+        errors.extend(body_errors)
+    if dependant.request_param_name:
+        values[dependant.request_param_name] = request
+    return values, errors
+
+
+def request_params_to_args(
+    required_params: List[Field], received_params: Dict[str, Any]
+) -> Tuple[Dict[str, Any], List[ErrorWrapper]]:
+    values = {}
+    errors = []
+    for field in required_params:
+        value = received_params.get(field.alias)
+        if value is None:
+            if field.required:
+                errors.append(
+                    ErrorWrapper(MissingError(), loc=field.alias, config=BaseConfig)
+                )
+            else:
+                values[field.name] = deepcopy(field.default)
+            continue
+        v_, errors_ = field.validate(
+            value, values, loc=(field.schema.in_.value, field.alias)
+        )
+        if isinstance(errors_, ErrorWrapper):
+            errors.append(errors_)
+        elif isinstance(errors_, list):
+            errors.extend(errors_)
+        else:
+            values[field.name] = v_
+    return values, errors
+
+
+async def request_body_to_args(
+    required_params: List[Field], received_body: Dict[str, Any]
+) -> Tuple[Dict[str, Any], List[ErrorWrapper]]:
+    values = {}
+    errors = []
+    if required_params:
+        field = required_params[0]
+        embed = getattr(field.schema, "embed", None)
+        if len(required_params) == 1 and not embed:
+            received_body = {field.alias: received_body}
+        for field in required_params:
+            value = received_body.get(field.alias)
+            if value is None:
+                if field.required:
+                    errors.append(
+                        ErrorWrapper(
+                            MissingError(), loc=("body", field.alias), config=BaseConfig
+                        )
+                    )
+                else:
+                    values[field.name] = deepcopy(field.default)
+                continue
+            v_, errors_ = field.validate(value, values, loc=("body", field.alias))
+            if isinstance(errors_, ErrorWrapper):
+                errors.append(errors_)
+            elif isinstance(errors_, list):
+                errors.extend(errors_)
+            else:
+                values[field.name] = v_
+    return values, errors
+
+
+def get_body_field(*, dependant: Dependant, name: str):
+    flat_dependant = get_flat_dependant(dependant)
+    if not flat_dependant.body_params:
+        return None
+    first_param = flat_dependant.body_params[0]
+    embed = getattr(first_param.schema, "embed", None)
+    if len(flat_dependant.body_params) == 1 and not embed:
+        return first_param
+    model_name = "Body_" + name
+    BodyModel = create_model(model_name)
+    for f in flat_dependant.body_params:
+        BodyModel.__fields__[f.name] = f
+    required = any(True for f in flat_dependant.body_params if f.required)
+    if any(isinstance(f.schema, params.File) for f in flat_dependant.body_params):
+        BodySchema = params.File
+    elif any(isinstance(f.schema, params.Form) for f in flat_dependant.body_params):
+        BodySchema = params.Form
+    else:
+        BodySchema = params.Body
+
+    field = Field(
+        name="body",
+        type_=BodyModel,
+        default=None,
+        required=required,
+        model_config=BaseConfig,
+        class_validators=[],
+        alias="body",
+        schema=BodySchema(None),
+    )
+    return field
similarity index 56%
rename from fastapi/pydantic_utils.py
rename to fastapi/encoders.py
index 8fc6589a4acdcca5de12e36f9c45e4e99d14b715..95ce4479e74b20579d44df680ec2a679737cafbb 100644 (file)
@@ -1,33 +1,44 @@
+from enum import Enum
 from types import GeneratorType
 from typing import Set
+
 from pydantic import BaseModel
-from enum import Enum
 from pydantic.json import pydantic_encoder
 
 
 def jsonable_encoder(
-    obj, include: Set[str] = None, exclude: Set[str] = set(), by_alias: bool = False, include_none=True,
+    obj,
+    include: Set[str] = None,
+    exclude: Set[str] = set(),
+    by_alias: bool = False,
+    include_none=True,
 ):
     if isinstance(obj, BaseModel):
         return jsonable_encoder(
-            obj.dict(include=include, exclude=exclude, by_alias=by_alias), include_none=include_none
+            obj.dict(include=include, exclude=exclude, by_alias=by_alias),
+            include_none=include_none,
         )
-    elif isinstance(obj, Enum):
+    if isinstance(obj, Enum):
         return obj.value
     if isinstance(obj, (str, int, float, type(None))):
         return obj
     if isinstance(obj, dict):
         return {
             jsonable_encoder(
-                key, by_alias=by_alias, include_none=include_none,
-            ): jsonable_encoder(
-                value, by_alias=by_alias, include_none=include_none,
-            )
-            for key, value in obj.items() if value is not None or include_none
+                key, by_alias=by_alias, include_none=include_none
+            ): jsonable_encoder(value, by_alias=by_alias, include_none=include_none)
+            for key, value in obj.items()
+            if value is not None or include_none
         }
     if isinstance(obj, (list, set, frozenset, GeneratorType, tuple)):
         return [
-            jsonable_encoder(item, include=include, exclude=exclude, by_alias=by_alias, include_none=include_none)
+            jsonable_encoder(
+                item,
+                include=include,
+                exclude=exclude,
+                by_alias=by_alias,
+                include_none=include_none,
+            )
             for item in obj
         ]
     return pydantic_encoder(obj)
diff --git a/fastapi/openapi/__init__.py b/fastapi/openapi/__init__.py
new file mode 100644 (file)
index 0000000..e69de29
diff --git a/fastapi/openapi/constants.py b/fastapi/openapi/constants.py
new file mode 100644 (file)
index 0000000..1d94a33
--- /dev/null
@@ -0,0 +1,2 @@
+METHODS_WITH_BODY = set(("POST", "PUT"))
+REF_PREFIX = "#/components/schemas/"
diff --git a/fastapi/openapi/models.py b/fastapi/openapi/models.py
new file mode 100644 (file)
index 0000000..e3d96bd
--- /dev/null
@@ -0,0 +1,347 @@
+import logging
+from enum import Enum
+from typing import Any, Dict, List, Optional, Union
+
+from pydantic import BaseModel, Schema as PSchema
+from pydantic.types import UrlStr
+
+try:
+    import pydantic.types.EmailStr
+    from pydantic.types import EmailStr
+except ImportError:
+    logging.warning(
+        "email-validator not installed, email fields will be treated as str"
+    )
+
+    class EmailStr(str):
+        pass
+
+
+class Contact(BaseModel):
+    name: Optional[str] = None
+    url: Optional[UrlStr] = None
+    email: Optional[EmailStr] = None
+
+
+class License(BaseModel):
+    name: str
+    url: Optional[UrlStr] = None
+
+
+class Info(BaseModel):
+    title: str
+    description: Optional[str] = None
+    termsOfService: Optional[str] = None
+    contact: Optional[Contact] = None
+    license: Optional[License] = None
+    version: str
+
+
+class ServerVariable(BaseModel):
+    enum: Optional[List[str]] = None
+    default: str
+    description: Optional[str] = None
+
+
+class Server(BaseModel):
+    url: UrlStr
+    description: Optional[str] = None
+    variables: Optional[Dict[str, ServerVariable]] = None
+
+
+class Reference(BaseModel):
+    ref: str = PSchema(..., alias="$ref")
+
+
+class Discriminator(BaseModel):
+    propertyName: str
+    mapping: Optional[Dict[str, str]] = None
+
+
+class XML(BaseModel):
+    name: Optional[str] = None
+    namespace: Optional[str] = None
+    prefix: Optional[str] = None
+    attribute: Optional[bool] = None
+    wrapped: Optional[bool] = None
+
+
+class ExternalDocumentation(BaseModel):
+    description: Optional[str] = None
+    url: UrlStr
+
+
+class SchemaBase(BaseModel):
+    ref: Optional[str] = PSchema(None, alias="$ref")
+    title: Optional[str] = None
+    multipleOf: Optional[float] = None
+    maximum: Optional[float] = None
+    exclusiveMaximum: Optional[float] = None
+    minimum: Optional[float] = None
+    exclusiveMinimum: Optional[float] = None
+    maxLength: Optional[int] = PSchema(None, gte=0)
+    minLength: Optional[int] = PSchema(None, gte=0)
+    pattern: Optional[str] = None
+    maxItems: Optional[int] = PSchema(None, gte=0)
+    minItems: Optional[int] = PSchema(None, gte=0)
+    uniqueItems: Optional[bool] = None
+    maxProperties: Optional[int] = PSchema(None, gte=0)
+    minProperties: Optional[int] = PSchema(None, gte=0)
+    required: Optional[List[str]] = None
+    enum: Optional[List[str]] = None
+    type: Optional[str] = None
+    allOf: Optional[List[Any]] = None
+    oneOf: Optional[List[Any]] = None
+    anyOf: Optional[List[Any]] = None
+    not_: Optional[List[Any]] = PSchema(None, alias="not")
+    items: Optional[Any] = None
+    properties: Optional[Dict[str, Any]] = None
+    additionalProperties: Optional[Union[bool, Any]] = None
+    description: Optional[str] = None
+    format: Optional[str] = None
+    default: Optional[Any] = None
+    nullable: Optional[bool] = None
+    discriminator: Optional[Discriminator] = None
+    readOnly: Optional[bool] = None
+    writeOnly: Optional[bool] = None
+    xml: Optional[XML] = None
+    externalDocs: Optional[ExternalDocumentation] = None
+    example: Optional[Any] = None
+    deprecated: Optional[bool] = None
+
+
+class Schema(SchemaBase):
+    allOf: Optional[List[SchemaBase]] = None
+    oneOf: Optional[List[SchemaBase]] = None
+    anyOf: Optional[List[SchemaBase]] = None
+    not_: Optional[List[SchemaBase]] = PSchema(None, alias="not")
+    items: Optional[SchemaBase] = None
+    properties: Optional[Dict[str, SchemaBase]] = None
+    additionalProperties: Optional[Union[bool, SchemaBase]] = None
+
+
+class Example(BaseModel):
+    summary: Optional[str] = None
+    description: Optional[str] = None
+    value: Optional[Any] = None
+    externalValue: Optional[UrlStr] = None
+
+
+class ParameterInType(Enum):
+    query = "query"
+    header = "header"
+    path = "path"
+    cookie = "cookie"
+
+
+class Encoding(BaseModel):
+    contentType: Optional[str] = None
+    # Workaround OpenAPI recursive reference, using Any
+    headers: Optional[Dict[str, Union[Any, Reference]]] = None
+    style: Optional[str] = None
+    explode: Optional[bool] = None
+    allowReserved: Optional[bool] = None
+
+
+class MediaType(BaseModel):
+    schema_: Optional[Union[Schema, Reference]] = PSchema(None, alias="schema")
+    example: Optional[Any] = None
+    examples: Optional[Dict[str, Union[Example, Reference]]] = None
+    encoding: Optional[Dict[str, Encoding]] = None
+
+
+class ParameterBase(BaseModel):
+    description: Optional[str] = None
+    required: Optional[bool] = None
+    deprecated: Optional[bool] = None
+    # Serialization rules for simple scenarios
+    style: Optional[str] = None
+    explode: Optional[bool] = None
+    allowReserved: Optional[bool] = None
+    schema_: Optional[Union[Schema, Reference]] = PSchema(None, alias="schema")
+    example: Optional[Any] = None
+    examples: Optional[Dict[str, Union[Example, Reference]]] = None
+    # Serialization rules for more complex scenarios
+    content: Optional[Dict[str, MediaType]] = None
+
+
+class Parameter(ParameterBase):
+    name: str
+    in_: ParameterInType = PSchema(..., alias="in")
+
+
+class Header(ParameterBase):
+    pass
+
+
+# Workaround OpenAPI recursive reference
+class EncodingWithHeaders(Encoding):
+    headers: Optional[Dict[str, Union[Header, Reference]]] = None
+
+
+class RequestBody(BaseModel):
+    description: Optional[str] = None
+    content: Dict[str, MediaType]
+    required: Optional[bool] = None
+
+
+class Link(BaseModel):
+    operationRef: Optional[str] = None
+    operationId: Optional[str] = None
+    parameters: Optional[Dict[str, Union[Any, str]]] = None
+    requestBody: Optional[Union[Any, str]] = None
+    description: Optional[str] = None
+    server: Optional[Server] = None
+
+
+class Response(BaseModel):
+    description: str
+    headers: Optional[Dict[str, Union[Header, Reference]]] = None
+    content: Optional[Dict[str, MediaType]] = None
+    links: Optional[Dict[str, Union[Link, Reference]]] = None
+
+
+class Responses(BaseModel):
+    default: Response
+
+
+class Operation(BaseModel):
+    tags: Optional[List[str]] = None
+    summary: Optional[str] = None
+    description: Optional[str] = None
+    externalDocs: Optional[ExternalDocumentation] = None
+    operationId: Optional[str] = None
+    parameters: Optional[List[Union[Parameter, Reference]]] = None
+    requestBody: Optional[Union[RequestBody, Reference]] = None
+    responses: Union[Responses, Dict[Union[str], Response]]
+    # Workaround OpenAPI recursive reference
+    callbacks: Optional[Dict[str, Union[Dict[str, Any], Reference]]] = None
+    deprecated: Optional[bool] = None
+    security: Optional[List[Dict[str, List[str]]]] = None
+    servers: Optional[List[Server]] = None
+
+
+class PathItem(BaseModel):
+    ref: Optional[str] = PSchema(None, alias="$ref")
+    summary: Optional[str] = None
+    description: Optional[str] = None
+    get: Optional[Operation] = None
+    put: Optional[Operation] = None
+    post: Optional[Operation] = None
+    delete: Optional[Operation] = None
+    options: Optional[Operation] = None
+    head: Optional[Operation] = None
+    patch: Optional[Operation] = None
+    trace: Optional[Operation] = None
+    servers: Optional[List[Server]] = None
+    parameters: Optional[List[Union[Parameter, Reference]]] = None
+
+
+# Workaround OpenAPI recursive reference
+class OperationWithCallbacks(BaseModel):
+    callbacks: Optional[Dict[str, Union[Dict[str, PathItem], Reference]]] = None
+
+
+class SecuritySchemeType(Enum):
+    apiKey = "apiKey"
+    http = "http"
+    oauth2 = "oauth2"
+    openIdConnect = "openIdConnect"
+
+
+class SecurityBase(BaseModel):
+    type_: SecuritySchemeType = PSchema(..., alias="type")
+    description: Optional[str] = None
+
+
+class APIKeyIn(Enum):
+    query = "query"
+    header = "header"
+    cookie = "cookie"
+
+
+class APIKey(SecurityBase):
+    type_ = PSchema(SecuritySchemeType.apiKey, alias="type")
+    in_: APIKeyIn = PSchema(..., alias="in")
+    name: str
+
+
+class HTTPBase(SecurityBase):
+    type_ = PSchema(SecuritySchemeType.http, alias="type")
+    scheme: str
+
+
+class HTTPBearer(HTTPBase):
+    scheme = "bearer"
+    bearerFormat: Optional[str] = None
+
+
+class OAuthFlow(BaseModel):
+    refreshUrl: Optional[str] = None
+    scopes: Dict[str, str] = {}
+
+
+class OAuthFlowImplicit(OAuthFlow):
+    authorizationUrl: str
+
+
+class OAuthFlowPassword(OAuthFlow):
+    tokenUrl: str
+
+
+class OAuthFlowClientCredentials(OAuthFlow):
+    tokenUrl: str
+
+
+class OAuthFlowAuthorizationCode(OAuthFlow):
+    authorizationUrl: str
+    tokenUrl: str
+
+
+class OAuthFlows(BaseModel):
+    implicit: Optional[OAuthFlowImplicit] = None
+    password: Optional[OAuthFlowPassword] = None
+    clientCredentials: Optional[OAuthFlowClientCredentials] = None
+    authorizationCode: Optional[OAuthFlowAuthorizationCode] = None
+
+
+class OAuth2(SecurityBase):
+    type_ = PSchema(SecuritySchemeType.oauth2, alias="type")
+    flows: OAuthFlows
+
+
+class OpenIdConnect(SecurityBase):
+    type_ = PSchema(SecuritySchemeType.openIdConnect, alias="type")
+    openIdConnectUrl: str
+
+
+SecurityScheme = Union[APIKey, HTTPBase, HTTPBearer, OAuth2, OpenIdConnect]
+
+
+class Components(BaseModel):
+    schemas: Optional[Dict[str, Union[Schema, Reference]]] = None
+    responses: Optional[Dict[str, Union[Response, Reference]]] = None
+    parameters: Optional[Dict[str, Union[Parameter, Reference]]] = None
+    examples: Optional[Dict[str, Union[Example, Reference]]] = None
+    requestBodies: Optional[Dict[str, Union[RequestBody, Reference]]] = None
+    headers: Optional[Dict[str, Union[Header, Reference]]] = None
+    securitySchemes: Optional[Dict[str, Union[SecurityScheme, Reference]]] = None
+    links: Optional[Dict[str, Union[Link, Reference]]] = None
+    callbacks: Optional[Dict[str, Union[Dict[str, PathItem], Reference]]] = None
+
+
+class Tag(BaseModel):
+    name: str
+    description: Optional[str] = None
+    externalDocs: Optional[ExternalDocumentation] = None
+
+
+class OpenAPI(BaseModel):
+    openapi: str
+    info: Info
+    servers: Optional[List[Server]] = None
+    paths: Dict[str, PathItem]
+    components: Optional[Components] = None
+    security: Optional[List[Dict[str, List[str]]]] = None
+    tags: Optional[List[Tag]] = None
+    externalDocs: Optional[ExternalDocumentation] = None
diff --git a/fastapi/openapi/utils.py b/fastapi/openapi/utils.py
new file mode 100644 (file)
index 0000000..3cf8007
--- /dev/null
@@ -0,0 +1,280 @@
+from typing import Any, Dict, Sequence, Type
+
+from starlette.responses import HTMLResponse, JSONResponse
+from starlette.routing import BaseRoute
+from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY
+
+from fastapi import routing
+from fastapi.dependencies.models import Dependant
+from fastapi.dependencies.utils import get_flat_dependant
+from fastapi.encoders import jsonable_encoder
+from fastapi.openapi.constants import REF_PREFIX, METHODS_WITH_BODY
+from fastapi.openapi.models import OpenAPI
+from fastapi.params import Body
+from fastapi.utils import get_flat_models_from_routes, get_model_definitions
+from pydantic.fields import Field
+from pydantic.schema import field_schema, get_model_name_map
+from pydantic.utils import lenient_issubclass
+
+validation_error_definition = {
+    "title": "ValidationError",
+    "type": "object",
+    "properties": {
+        "loc": {"title": "Location", "type": "array", "items": {"type": "string"}},
+        "msg": {"title": "Message", "type": "string"},
+        "type": {"title": "Error Type", "type": "string"},
+    },
+    "required": ["loc", "msg", "type"],
+}
+
+validation_error_response_definition = {
+    "title": "HTTPValidationError",
+    "type": "object",
+    "properties": {
+        "detail": {
+            "title": "Detail",
+            "type": "array",
+            "items": {"$ref": REF_PREFIX + "ValidationError"},
+        }
+    },
+}
+
+
+def get_openapi_params(dependant: Dependant):
+    flat_dependant = get_flat_dependant(dependant)
+    return (
+        flat_dependant.path_params
+        + flat_dependant.query_params
+        + flat_dependant.header_params
+        + flat_dependant.cookie_params
+    )
+
+def get_openapi_path(*, route: BaseRoute, model_name_map: Dict[Type, str]):
+    if not (route.include_in_schema and isinstance(route, routing.APIRoute)):
+        return None
+    path = {}
+    security_schemes = {}
+    definitions = {}
+    for method in route.methods:
+        operation: Dict[str, Any] = {}
+        if route.tags:
+            operation["tags"] = route.tags
+        if route.summary:
+            operation["summary"] = route.summary
+        if route.description:
+            operation["description"] = route.description
+        if route.operation_id:
+            operation["operationId"] = route.operation_id
+        else:
+            operation["operationId"] = route.name
+        if route.deprecated:
+            operation["deprecated"] = route.deprecated
+        parameters = []
+        flat_dependant = get_flat_dependant(route.dependant)
+        security_definitions = {}
+        for security_requirement in flat_dependant.security_requirements:
+            security_definition = jsonable_encoder(
+                security_requirement.security_scheme,
+                exclude={"scheme_name"},
+                by_alias=True,
+                include_none=False,
+            )
+            security_name = (
+                getattr(
+                    security_requirement.security_scheme, "scheme_name", None
+                )
+                or security_requirement.security_scheme.__class__.__name__
+            )
+            security_definitions[security_name] = security_definition
+            operation.setdefault("security", []).append(
+                {security_name: security_requirement.scopes}
+            )
+        if security_definitions:
+            security_schemes.update(
+                security_definitions
+            )
+        all_route_params = get_openapi_params(route.dependant)
+        for param in all_route_params:
+            if "ValidationError" not in definitions:
+                definitions["ValidationError"] = validation_error_definition
+                definitions[
+                    "HTTPValidationError"
+                ] = validation_error_response_definition
+            parameter = {
+                "name": param.alias,
+                "in": param.schema.in_.value,
+                "required": param.required,
+                "schema": field_schema(param, model_name_map={})[0],
+            }
+            if param.schema.description:
+                parameter["description"] = param.schema.description
+            if param.schema.deprecated:
+                parameter["deprecated"] = param.schema.deprecated
+            parameters.append(parameter)
+        if parameters:
+            operation["parameters"] = parameters
+        if method in METHODS_WITH_BODY:
+            body_field = route.body_field
+            if body_field:
+                assert isinstance(body_field, Field)
+                body_schema, _ = field_schema(
+                    body_field,
+                    model_name_map=model_name_map,
+                    ref_prefix=REF_PREFIX,
+                )
+                if isinstance(body_field.schema, Body):
+                    request_media_type = body_field.schema.media_type
+                else:
+                    # Includes not declared media types (Schema)
+                    request_media_type = "application/json"
+                required = body_field.required
+                request_body_oai = {}
+                if required:
+                    request_body_oai["required"] = required
+                request_body_oai["content"] = {
+                    request_media_type: {"schema": body_schema}
+                }
+                operation["requestBody"] = request_body_oai
+        response_code = str(route.response_code)
+        response_schema = {"type": "string"}
+        if lenient_issubclass(route.response_wrapper, JSONResponse):
+            response_media_type = "application/json"
+            if route.response_field:
+                response_schema, _ = field_schema(
+                    route.response_field,
+                    model_name_map=model_name_map,
+                    ref_prefix=REF_PREFIX,
+                )
+            else:
+                response_schema = {}
+        elif lenient_issubclass(route.response_wrapper, HTMLResponse):
+            response_media_type = "text/html"
+        else:
+            response_media_type = "text/plain"
+        content = {response_media_type: {"schema": response_schema}}
+        operation["responses"] = {
+            response_code: {
+                "description": route.response_description,
+                "content": content,
+            }
+        }
+        if all_route_params or route.body_field:
+            operation["responses"][str(HTTP_422_UNPROCESSABLE_ENTITY)] = {
+                "description": "Validation Error",
+                "content": {
+                    "application/json": {
+                        "schema": {"$ref": REF_PREFIX + "HTTPValidationError"}
+                    }
+                },
+            }
+        path[method.lower()] = operation
+    return path, security_schemes, definitions
+
+
+def get_openapi(
+    *,
+    title: str,
+    version: str,
+    openapi_version: str = "3.0.2",
+    description: str = None,
+    routes: Sequence[BaseRoute]
+):
+    info = {"title": title, "version": version}
+    if description:
+        info["description"] = description
+    output = {"openapi": openapi_version, "info": info}
+    components: Dict[str, Dict] = {}
+    paths: Dict[str, Dict] = {}
+    flat_models = get_flat_models_from_routes(routes)
+    model_name_map = get_model_name_map(flat_models)
+    definitions = get_model_definitions(
+        flat_models=flat_models, model_name_map=model_name_map
+    )
+    for route in routes:
+        result = get_openapi_path(route=route, model_name_map=model_name_map)
+        if result:
+            path, security_schemes, path_definitions = result
+            if path:
+                paths.setdefault(route.path, {}).update(path)
+            if security_schemes:
+                components.setdefault("securitySchemes", {}).update(security_schemes)
+            if path_definitions:
+                definitions.update(path_definitions)
+    if definitions:
+        components.setdefault("schemas", {}).update(definitions)
+    if components:
+        output["components"] = components
+    output["paths"] = paths
+    return jsonable_encoder(OpenAPI(**output), by_alias=True, include_none=False)
+
+
+def get_swagger_ui_html(*, openapi_url: str, title: str):
+    return HTMLResponse(
+        """
+    <! doctype html>
+    <html>
+    <head>
+    <link type="text/css" rel="stylesheet" href="//unpkg.com/swagger-ui-dist@3/swagger-ui.css">
+    <title>
+    """ + title + """
+    </title>
+    </head>
+    <body>
+    <div id="swagger-ui">
+    </div>
+    <script src="//unpkg.com/swagger-ui-dist@3/swagger-ui-bundle.js"></script>
+    <!-- `SwaggerUIBundle` is now available on the page -->
+    <script>
+            
+    const ui = SwaggerUIBundle({
+        url: '"""
+        + openapi_url
+        + """',
+        dom_id: '#swagger-ui',
+        presets: [
+        SwaggerUIBundle.presets.apis,
+        SwaggerUIBundle.SwaggerUIStandalonePreset
+        ],
+        layout: "BaseLayout"
+    })
+    </script>
+    </body>
+    </html>
+    """,
+        media_type="text/html",
+    )
+
+
+def get_redoc_html(*, openapi_url: str, title: str):
+    return HTMLResponse(
+        """
+    <!DOCTYPE html>
+<html>
+  <head>
+    <title>
+    """ + title + """
+    </title>
+    <!-- needed for adaptive design -->
+    <meta charset="utf-8"/>
+    <meta name="viewport" content="width=device-width, initial-scale=1">
+    <link href="https://fonts.googleapis.com/css?family=Montserrat:300,400,700|Roboto:300,400,700" rel="stylesheet">
+
+    <!--
+    ReDoc doesn't change outer page styles
+    -->
+    <style>
+      body {
+        margin: 0;
+        padding: 0;
+      }
+    </style>
+  </head>
+  <body>
+    <redoc spec-url='""" + openapi_url + """'></redoc>
+    <script src="https://cdn.jsdelivr.net/npm/redoc@next/bundles/redoc.standalone.js"> </script>
+  </body>
+</html>
+    """,
+        media_type="text/html",
+    )
index 98b80943cce2527a0061dfd199a3cacbf77903ea..abbce8aeb9825101d8168fc91654c1d297bd0944 100644 (file)
@@ -1,5 +1,6 @@
-from typing import Sequence
 from enum import Enum
+from typing import Sequence, Any, Dict
+
 from pydantic import Schema
 
 
@@ -12,6 +13,7 @@ class ParamTypes(Enum):
 
 class Param(Schema):
     in_: ParamTypes
+
     def __init__(
         self,
         default,
@@ -27,7 +29,7 @@ class Param(Schema):
         min_length: int = None,
         max_length: int = None,
         regex: str = None,
-        **extra: object,
+        **extra: Dict[str, Any],
     ):
         self.deprecated = deprecated
         super().__init__(
@@ -64,7 +66,7 @@ class Path(Param):
         min_length: int = None,
         max_length: int = None,
         regex: str = None,
-        **extra: object,
+        **extra: Dict[str, Any],
     ):
         self.description = description
         self.deprecated = deprecated
@@ -103,7 +105,7 @@ class Query(Param):
         min_length: int = None,
         max_length: int = None,
         regex: str = None,
-        **extra: object,
+        **extra: Dict[str, Any],
     ):
         self.description = description
         self.deprecated = deprecated
@@ -141,7 +143,7 @@ class Header(Param):
         min_length: int = None,
         max_length: int = None,
         regex: str = None,
-        **extra: object,
+        **extra: Dict[str, Any],
     ):
         self.description = description
         self.deprecated = deprecated
@@ -179,7 +181,7 @@ class Cookie(Param):
         min_length: int = None,
         max_length: int = None,
         regex: str = None,
-        **extra: object,
+        **extra: Dict[str, Any],
     ):
         self.description = description
         self.deprecated = deprecated
@@ -200,11 +202,49 @@ class Cookie(Param):
 
 
 class Body(Schema):
+    def __init__(
+        self,
+        default,
+        *,
+        embed=False,
+        media_type: str = "application/json",
+        alias: str = None,
+        title: str = None,
+        description: str = None,
+        gt: float = None,
+        ge: float = None,
+        lt: float = None,
+        le: float = None,
+        min_length: int = None,
+        max_length: int = None,
+        regex: str = None,
+        **extra: Dict[str, Any],
+    ):
+        self.embed = embed
+        self.media_type = media_type
+        super().__init__(
+            default,
+            alias=alias,
+            title=title,
+            description=description,
+            gt=gt,
+            ge=ge,
+            lt=lt,
+            le=le,
+            min_length=min_length,
+            max_length=max_length,
+            regex=regex,
+            **extra,
+        )
+
+
+class Form(Body):
     def __init__(
         self,
         default,
         *,
         sub_key=False,
+        media_type: str = "application/x-www-form-urlencoded",
         alias: str = None,
         title: str = None,
         description: str = None,
@@ -215,11 +255,49 @@ class Body(Schema):
         min_length: int = None,
         max_length: int = None,
         regex: str = None,
-        **extra: object,
+        **extra: Dict[str, Any],
     ):
-        self.sub_key = sub_key
         super().__init__(
             default,
+            embed=sub_key,
+            media_type=media_type,
+            alias=alias,
+            title=title,
+            description=description,
+            gt=gt,
+            ge=ge,
+            lt=lt,
+            le=le,
+            min_length=min_length,
+            max_length=max_length,
+            regex=regex,
+            **extra,
+        )
+
+
+class File(Form):
+    def __init__(
+        self,
+        default,
+        *,
+        sub_key=False,
+        media_type: str = "multipart/form-data",
+        alias: str = None,
+        title: str = None,
+        description: str = None,
+        gt: float = None,
+        ge: float = None,
+        lt: float = None,
+        le: float = None,
+        min_length: int = None,
+        max_length: int = None,
+        regex: str = None,
+        **extra: Dict[str, Any],
+    ):
+        super().__init__(
+            default,
+            embed=sub_key,
+            media_type=media_type,
             alias=alias,
             title=title,
             description=description,
@@ -235,12 +313,11 @@ class Body(Schema):
 
 
 class Depends:
-    def __init__(self, dependency = None):
+    def __init__(self, dependency=None):
         self.dependency = dependency
 
 
-class Security:
-    def __init__(self, security_scheme = None, scopes: Sequence[str] = None):
-        self.security_scheme = security_scheme
-        self.scopes = scopes
-
+class Security(Depends):
+    def __init__(self, dependency=None, scopes: Sequence[str] = None):
+        self.scopes = scopes or []
+        super().__init__(dependency=dependency)
index 8c95b3327c61a23447e9996778f98cdb2a1187e2..6f7d592e5a6ad9c436a98b7ca7142537f2582a4d 100644 (file)
 import asyncio
 import inspect
-import re
-import typing
-from copy import deepcopy
+from typing import Callable, List, Type
 
 from starlette import routing
-from starlette.routing import get_name, request_response
-from starlette.requests import Request
-from starlette.responses import Response, JSONResponse
 from starlette.concurrency import run_in_threadpool
 from starlette.exceptions import HTTPException
+from starlette.formparsers import UploadFile
+from starlette.requests import Request
+from starlette.responses import JSONResponse, Response
+from starlette.routing import get_name, request_response
 from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY
 
-
-from pydantic.fields import Field, Required
-from pydantic.schema import get_annotation_from_schema
-from pydantic import BaseConfig, BaseModel, create_model, Schema
+from fastapi import params
+from fastapi.dependencies.models import Dependant
+from fastapi.dependencies.utils import get_body_field, get_dependant, solve_dependencies
+from fastapi.encoders import jsonable_encoder
+from pydantic import BaseConfig, BaseModel, Schema
 from pydantic.error_wrappers import ErrorWrapper, ValidationError
-from pydantic.errors import MissingError
+from pydantic.fields import Field
 from pydantic.utils import lenient_issubclass
-from .pydantic_utils import jsonable_encoder
-
-from fastapi import params
-from fastapi.security.base import SecurityBase
-
 
-param_supported_types = (str, int, float, bool)
 
-
-class Dependant:
-    def __init__(
-        self,
-        *,
-        path_params: typing.List[Field] = None,
-        query_params: typing.List[Field] = None,
-        header_params: typing.List[Field] = None,
-        cookie_params: typing.List[Field] = None,
-        body_params: typing.List[Field] = None,
-        dependencies: typing.List["Dependant"] = None,
-        security_schemes: typing.List[Field] = None,
-        name: str = None,
-        call: typing.Callable = None,
-        request_param_name: str = None,
-    ) -> None:
-        self.path_params: typing.List[Field] = path_params or []
-        self.query_params: typing.List[Field] = query_params or []
-        self.header_params: typing.List[Field] = header_params or []
-        self.cookie_params: typing.List[Field] = cookie_params or []
-        self.body_params: typing.List[Field] = body_params or []
-        self.dependencies: typing.List[Dependant] = dependencies or []
-        self.security_schemes: typing.List[Field] = security_schemes or []
-        self.request_param_name = request_param_name
-        self.name = name
-        self.call: typing.Callable = call
-
-
-def request_params_to_args(
-    required_params: typing.List[Field], received_params: typing.Dict[str, typing.Any]
-) -> typing.Tuple[typing.Dict[str, typing.Any], typing.List[ErrorWrapper]]:
-    values = {}
-    errors = []
-    for field in required_params:
-        value = received_params.get(field.alias)
-        if value is None:
-            if field.required:
-                errors.append(
-                    ErrorWrapper(MissingError(), loc=field.alias, config=BaseConfig)
-                )
-            else:
-                values[field.name] = deepcopy(field.default)
-            continue
-        v_, errors_ = field.validate(
-            value, values, loc=(field.schema.in_.value, field.alias)
-        )
+def serialize_response(*, field: Field = None, response):
+    if field:
+        errors = []
+        value, errors_ = field.validate(response, {}, loc=("response",))
         if isinstance(errors_, ErrorWrapper):
-            errors_: ErrorWrapper
             errors.append(errors_)
         elif isinstance(errors_, list):
             errors.extend(errors_)
-        else:
-            values[field.name] = v_
-    return values, errors
-
-
-def request_body_to_args(
-    required_params: typing.List[Field], received_body: typing.Dict[str, typing.Any]
-) -> typing.Tuple[typing.Dict[str, typing.Any], typing.List[ErrorWrapper]]:
-    values = {}
-    errors = []
-    if required_params:
-        field = required_params[0]
-        sub_key = getattr(field.schema, "sub_key", None)
-        if len(required_params) == 1 and not sub_key:
-            received_body = {field.alias: received_body}
-        for field in required_params:
-            value = received_body.get(field.alias)
-            if value is None:
-                if field.required:
-                    errors.append(
-                        ErrorWrapper(
-                            MissingError(), loc=("body", field.alias), config=BaseConfig
-                        )
-                    )
-                else:
-                    values[field.name] = deepcopy(field.default)
-                continue
-
-            v_, errors_ = field.validate(value, values, loc=("body", field.alias))
-            if isinstance(errors_, ErrorWrapper):
-                errors_: ErrorWrapper
-                errors.append(errors_)
-            elif isinstance(errors_, list):
-                errors.extend(errors_)
-            else:
-                values[field.name] = v_
-    return values, errors
-
-
-def add_param_to_fields(
-    *,
-    param: inspect.Parameter,
-    dependant: Dependant,
-    default_schema=params.Param,
-    force_type: params.ParamTypes = None,
-):
-    default_value = Required
-    if not param.default == param.empty:
-        default_value = param.default
-    if isinstance(default_value, params.Param):
-        schema = default_value
-        default_value = schema.default
-        if schema.in_ is None:
-            schema.in_ = default_schema.in_
-        if force_type:
-            schema.in_ = force_type
-    else:
-        schema = default_schema(default_value)
-    required = default_value == Required
-    annotation = typing.Any
-    if not param.annotation == param.empty:
-        annotation = param.annotation
-    annotation = get_annotation_from_schema(annotation, schema)
-    Config = BaseConfig
-    field = Field(
-        name=param.name,
-        type_=annotation,
-        default=None if required else default_value,
-        alias=schema.alias or param.name,
-        required=required,
-        model_config=Config,
-        class_validators=[],
-        schema=schema,
-    )
-    if schema.in_ == params.ParamTypes.path:
-        dependant.path_params.append(field)
-    elif schema.in_ == params.ParamTypes.query:
-        dependant.query_params.append(field)
-    elif schema.in_ == params.ParamTypes.header:
-        dependant.header_params.append(field)
-    else:
-        assert (
-            schema.in_ == params.ParamTypes.cookie
-        ), f"non-body parameters must be in path, query, header or cookie: {param.name}"
-        dependant.cookie_params.append(field)
-
-
-def add_param_to_body_fields(*, param: inspect.Parameter, dependant: Dependant):
-    default_value = Required
-    if not param.default == param.empty:
-        default_value = param.default
-    if isinstance(default_value, Schema):
-        schema = default_value
-        default_value = schema.default
+        if errors:
+            raise ValidationError(errors)
+        return jsonable_encoder(value)
     else:
-        schema = Schema(default_value)
-    required = default_value == Required
-    annotation = get_annotation_from_schema(param.annotation, schema)
-    field = Field(
-        name=param.name,
-        type_=annotation,
-        default=None if required else default_value,
-        alias=schema.alias or param.name,
-        required=required,
-        model_config=BaseConfig,
-        class_validators=[],
-        schema=schema,
-    )
-    dependant.body_params.append(field)
+        return jsonable_encoder(response)
 
 
-def get_sub_dependant(
-    *, param: inspect.Parameter, path: str
+def get_app(
+    dependant: Dependant,
+    body_field: Field = None,
+    response_code: str = 200,
+    response_wrapper: Type[Response] = JSONResponse,
+    response_field: Type[Field] = None,
 ):
-    depends: params.Depends = param.default
-    if depends.dependency:
-        dependency = depends.dependency
-    else:
-        dependency = param.annotation
-    assert callable(dependency)
-    sub_dependant = get_dependant(path=path, call=dependency, name=param.name)
-    if isinstance(dependency, SecurityBase):
-        sub_dependant.security_schemes.append(dependency)
-    return sub_dependant
-
-
-def get_flat_dependant(dependant: Dependant):
-    flat_dependant = Dependant(
-        path_params=dependant.path_params.copy(),
-        query_params=dependant.query_params.copy(),
-        header_params=dependant.header_params.copy(),
-        cookie_params=dependant.cookie_params.copy(),
-        body_params=dependant.body_params.copy(),
-        security_schemes=dependant.security_schemes.copy(),
-    )
-    for sub_dependant in dependant.dependencies:
-        if sub_dependant is dependant:
-            raise ValueError("recursion", dependant.dependencies)
-        flat_sub = get_flat_dependant(sub_dependant)
-        flat_dependant.path_params.extend(flat_sub.path_params)
-        flat_dependant.query_params.extend(flat_sub.query_params)
-        flat_dependant.header_params.extend(flat_sub.header_params)
-        flat_dependant.cookie_params.extend(flat_sub.cookie_params)
-        flat_dependant.body_params.extend(flat_sub.body_params)
-        flat_dependant.security_schemes.extend(flat_sub.security_schemes)
-    return flat_dependant
-
-
-def get_path_param_names(path: str):
-    return {item.strip("{}") for item in re.findall("{[^}]*}", path)}
-
-
-def get_dependant(*, path: str, call: typing.Callable, name: str = None):
-    path_param_names = get_path_param_names(path)
-    endpoint_signature = inspect.signature(call)
-    signature_params = endpoint_signature.parameters
-    dependant = Dependant(call=call, name=name)
-    for param_name in signature_params:
-        param = signature_params[param_name]
-        if isinstance(param.default, params.Depends):
-            sub_dependant = get_sub_dependant(param=param, path=path)
-            dependant.dependencies.append(sub_dependant)
-    for param_name in signature_params:
-        param = signature_params[param_name]
-        if (
-            (param.default == param.empty) or isinstance(param.default, params.Path)
-        ) and (param_name in path_param_names):
-            assert lenient_issubclass(
-                param.annotation, param_supported_types
-            ), f"Path params must be of type str, int, float or boot: {param}"
-            param = signature_params[param_name]
-            add_param_to_fields(
-                param=param,
-                dependant=dependant,
-                default_schema=params.Path,
-                force_type=params.ParamTypes.path,
-            )
-        elif (param.default == param.empty or param.default is None) and (
-            param.annotation == param.empty
-            or lenient_issubclass(param.annotation, param_supported_types)
-        ):
-            add_param_to_fields(
-                param=param, dependant=dependant, default_schema=params.Query
-            )
-        elif isinstance(param.default, params.Param):
-            if param.annotation != param.empty:
-                assert lenient_issubclass(
-                    param.annotation, param_supported_types
-                ), f"Parameters for Path, Query, Header and Cookies must be of type str, int, float or bool: {param}"
-            add_param_to_fields(
-                param=param, dependant=dependant, default_schema=params.Query
-            )
-        elif lenient_issubclass(param.annotation, Request):
-            dependant.request_param_name = param_name
-        elif not isinstance(param.default, params.Depends):
-            add_param_to_body_fields(param=param, dependant=dependant)
-    return dependant
-
-
-def is_coroutine_callable(call: typing.Callable):
-    if inspect.isfunction(call):
-        return asyncio.iscoroutinefunction(call)
-    elif inspect.isclass(call):
-        return False
-    else:
-        call = getattr(call, "__call__", None)
-        if not call:
-            return False
-        else:
-            return asyncio.iscoroutinefunction(call)
-
-
-async def solve_dependencies(*, request: Request, dependant: Dependant):
-    values = {}
-    errors = []
-    for sub_dependant in dependant.dependencies:
-        sub_values, sub_errors = await solve_dependencies(
-            request=request, dependant=sub_dependant
-        )
-        if sub_errors:
-            return {}, errors
-        if is_coroutine_callable(sub_dependant.call):
-            solved = await sub_dependant.call(**sub_values)
-        else:
-            solved = await run_in_threadpool(sub_dependant.call, **sub_values)
-        values[sub_dependant.name] = solved
-    path_values, path_errors = request_params_to_args(
-        dependant.path_params, request.path_params
-    )
-    query_values, query_errors = request_params_to_args(
-        dependant.query_params, request.query_params
-    )
-    header_values, header_errors = request_params_to_args(
-        dependant.header_params, request.headers
-    )
-    cookie_values, cookie_errors = request_params_to_args(
-        dependant.cookie_params, request.cookies
-    )
-    values.update(path_values)
-    values.update(query_values)
-    values.update(header_values)
-    values.update(cookie_values)
-    errors = path_errors + query_errors + header_errors + cookie_errors
-    if dependant.body_params:
-        body = await request.json()
-        body_values, body_errors = request_body_to_args(dependant.body_params, body)
-        values.update(body_values)
-        errors.extend(body_errors)
-    if dependant.request_param_name:
-        values[dependant.request_param_name] = request
-    return values, errors
-
-
-def get_app(dependant: Dependant):
-    is_coroutine = asyncio.iscoroutinefunction(dependant.call)
+    is_coroutine = dependant.call and asyncio.iscoroutinefunction(dependant.call)
 
     async def app(request: Request) -> Response:
-        values, errors = await solve_dependencies(request=request, dependant=dependant)
+        body = None
+        if body_field:
+            if isinstance(body_field.schema, params.Form):
+                raw_body = await request.form()
+                body = {}
+                for field, value in raw_body.items():
+                    if isinstance(value, UploadFile):
+                        body[field] = await value.read()
+                    else:
+                        body[field] = value
+            else:
+                body = await request.json()
+        values, errors = await solve_dependencies(
+            request=request, dependant=dependant, body=body
+        )
         if errors:
             errors_out = ValidationError(errors)
             raise HTTPException(
@@ -348,36 +73,56 @@ def get_app(dependant: Dependant):
                 raw_response = await run_in_threadpool(dependant.call, **values)
             if isinstance(raw_response, Response):
                 return raw_response
-            else:
-                return JSONResponse(content=jsonable_encoder(raw_response))
-    return app
-
+            if isinstance(raw_response, BaseModel):
+                return response_wrapper(
+                    content=jsonable_encoder(raw_response), status_code=response_code
+                )
+            errors = []
+            try:
+                return response_wrapper(
+                    content=serialize_response(
+                        field=response_field, response=raw_response
+                    ),
+                    status_code=response_code,
+                )
+            except Exception as e:
+                errors.append(e)
+            try:
+                response = dict(raw_response)
+                return response_wrapper(
+                    content=serialize_response(field=response_field, response=response),
+                    status_code=response_code,
+                )
+            except Exception as e:
+                errors.append(e)
+            try:
+                response = vars(raw_response)
+                return response_wrapper(
+                    content=serialize_response(field=response_field, response=response),
+                    status_code=response_code,
+                )
+            except Exception as e:
+                errors.append(e)
+                raise ValueError(errors)
 
-def get_openapi_params(dependant: Dependant):
-    flat_dependant = get_flat_dependant(dependant)
-    return (
-        flat_dependant.path_params
-        + flat_dependant.query_params
-        + flat_dependant.header_params
-        + flat_dependant.cookie_params
-    )
+    return app
 
 
 class APIRoute(routing.Route):
     def __init__(
         self,
         path: str,
-        endpoint: typing.Callable,
+        endpoint: Callable,
         *,
-        methods: typing.List[str] = None,
+        methods: List[str] = None,
         name: str = None,
         include_in_schema: bool = True,
-        tags: typing.List[str] = [],
+        tags: List[str] = None,
         summary: str = None,
         description: str = None,
         operation_id: str = None,
         deprecated: bool = None,
-        response_type: typing.Type = None,
+        response_type: Type = None,
         response_description: str = "Successful Response",
         response_code=200,
         response_wrapper=JSONResponse,
@@ -392,12 +137,12 @@ class APIRoute(routing.Route):
         self.endpoint = endpoint
         self.name = get_name(endpoint) if name is None else name
         self.include_in_schema = include_in_schema
-        self.tags = tags
+        self.tags = tags or []
         self.summary = summary
-        self.description = description
+        self.description = description or self.endpoint.__doc__
         self.operation_id = operation_id
         self.deprecated = deprecated
-        self.request_body: typing.Union[BaseModel, Field, None] = None
+        self.body_field: Field = None
         self.response_description = response_description
         self.response_code = response_code
         self.response_wrapper = response_wrapper
@@ -430,53 +175,32 @@ class APIRoute(routing.Route):
         ), f"An endpoint must be a function or method"
 
         self.dependant = get_dependant(path=path, call=self.endpoint)
-        # flat_dependant = get_flat_dependant(self.dependant)
-        # path_param_names = get_path_param_names(path)
-        # for path_param in path_param_names:
-        #     assert path_param in {
-        #         f.alias for f in flat_dependant.path_params
-        #     }, f"Path parameter must be defined as a function parameter or be defined by a dependency: {path_param}"
-
-        if self.dependant.body_params:
-            first_param = self.dependant.body_params[0]
-            sub_key = getattr(first_param.schema, "sub_key", None)
-            if len(self.dependant.body_params) == 1 and not sub_key:
-                self.request_body = first_param
-            else:
-                model_name = "Body_" + self.name
-                BodyModel = create_model(model_name)
-                for f in self.dependant.body_params:
-                    BodyModel.__fields__[f.name] = f
-                required = any(True for f in self.dependant.body_params if f.required)
-                field = Field(
-                    name="body",
-                    type_=BodyModel,
-                    default=None,
-                    required=required,
-                    model_config=BaseConfig,
-                    class_validators=[],
-                    alias="body",
-                    schema=Schema(None),
-                )
-                self.request_body = field
-
-        self.app = request_response(get_app(dependant=self.dependant))
+        self.body_field = get_body_field(dependant=self.dependant, name=self.name)
+        self.app = request_response(
+            get_app(
+                dependant=self.dependant,
+                body_field=self.body_field,
+                response_code=self.response_code,
+                response_wrapper=self.response_wrapper,
+                response_field=self.response_field,
+            )
+        )
 
 
 class APIRouter(routing.Router):
     def add_api_route(
         self,
         path: str,
-        endpoint: typing.Callable,
-        methods: typing.List[str] = None,
+        endpoint: Callable,
+        methods: List[str] = None,
         name: str = None,
         include_in_schema: bool = True,
-        tags: typing.List[str] = [],
+        tags: List[str] = None,
         summary: str = None,
         description: str = None,
         operation_id: str = None,
         deprecated: bool = None,
-        response_type: typing.Type = None,
+        response_type: Type = None,
         response_description: str = "Successful Response",
         response_code=200,
         response_wrapper=JSONResponse,
@@ -487,7 +211,7 @@ class APIRouter(routing.Router):
             methods=methods,
             name=name,
             include_in_schema=include_in_schema,
-            tags=tags,
+            tags=tags or [],
             summary=summary,
             description=description,
             operation_id=operation_id,
@@ -502,27 +226,27 @@ class APIRouter(routing.Router):
     def api_route(
         self,
         path: str,
-        methods: typing.List[str] = None,
+        methods: List[str] = None,
         name: str = None,
         include_in_schema: bool = True,
-        tags: typing.List[str] = [],
+        tags: List[str] = None,
         summary: str = None,
         description: str = None,
         operation_id: str = None,
         deprecated: bool = None,
-        response_type: typing.Type = None,
+        response_type: Type = None,
         response_description: str = "Successful Response",
         response_code=200,
         response_wrapper=JSONResponse,
-    ) -> typing.Callable:
-        def decorator(func: typing.Callable) -> typing.Callable:
+    ) -> Callable:
+        def decorator(func: Callable) -> Callable:
             self.add_api_route(
                 path,
                 func,
                 methods=methods,
                 name=name,
                 include_in_schema=include_in_schema,
-                tags=tags,
+                tags=tags or [],
                 summary=summary,
                 description=description,
                 operation_id=operation_id,
@@ -541,12 +265,12 @@ class APIRouter(routing.Router):
         path: str,
         name: str = None,
         include_in_schema: bool = True,
-        tags: typing.List[str] = [],
+        tags: List[str] = None,
         summary: str = None,
         description: str = None,
         operation_id: str = None,
         deprecated: bool = None,
-        response_type: typing.Type = None,
+        response_type: Type = None,
         response_description: str = "Successful Response",
         response_code=200,
         response_wrapper=JSONResponse,
@@ -556,7 +280,7 @@ class APIRouter(routing.Router):
             methods=["GET"],
             name=name,
             include_in_schema=include_in_schema,
-            tags=tags,
+            tags=tags or [],
             summary=summary,
             description=description,
             operation_id=operation_id,
@@ -572,12 +296,12 @@ class APIRouter(routing.Router):
         path: str,
         name: str = None,
         include_in_schema: bool = True,
-        tags: typing.List[str] = [],
+        tags: List[str] = None,
         summary: str = None,
         description: str = None,
         operation_id: str = None,
         deprecated: bool = None,
-        response_type: typing.Type = None,
+        response_type: Type = None,
         response_description: str = "Successful Response",
         response_code=200,
         response_wrapper=JSONResponse,
@@ -587,7 +311,7 @@ class APIRouter(routing.Router):
             methods=["PUT"],
             name=name,
             include_in_schema=include_in_schema,
-            tags=tags,
+            tags=tags or [],
             summary=summary,
             description=description,
             operation_id=operation_id,
@@ -603,12 +327,12 @@ class APIRouter(routing.Router):
         path: str,
         name: str = None,
         include_in_schema: bool = True,
-        tags: typing.List[str] = [],
+        tags: List[str] = None,
         summary: str = None,
         description: str = None,
         operation_id: str = None,
         deprecated: bool = None,
-        response_type: typing.Type = None,
+        response_type: Type = None,
         response_description: str = "Successful Response",
         response_code=200,
         response_wrapper=JSONResponse,
@@ -618,7 +342,7 @@ class APIRouter(routing.Router):
             methods=["POST"],
             name=name,
             include_in_schema=include_in_schema,
-            tags=tags,
+            tags=tags or [],
             summary=summary,
             description=description,
             operation_id=operation_id,
@@ -634,12 +358,12 @@ class APIRouter(routing.Router):
         path: str,
         name: str = None,
         include_in_schema: bool = True,
-        tags: typing.List[str] = [],
+        tags: List[str] = None,
         summary: str = None,
         description: str = None,
         operation_id: str = None,
         deprecated: bool = None,
-        response_type: typing.Type = None,
+        response_type: Type = None,
         response_description: str = "Successful Response",
         response_code=200,
         response_wrapper=JSONResponse,
@@ -649,7 +373,7 @@ class APIRouter(routing.Router):
             methods=["DELETE"],
             name=name,
             include_in_schema=include_in_schema,
-            tags=tags,
+            tags=tags or [],
             summary=summary,
             description=description,
             operation_id=operation_id,
@@ -665,12 +389,12 @@ class APIRouter(routing.Router):
         path: str,
         name: str = None,
         include_in_schema: bool = True,
-        tags: typing.List[str] = [],
+        tags: List[str] = None,
         summary: str = None,
         description: str = None,
         operation_id: str = None,
         deprecated: bool = None,
-        response_type: typing.Type = None,
+        response_type: Type = None,
         response_description: str = "Successful Response",
         response_code=200,
         response_wrapper=JSONResponse,
@@ -680,7 +404,7 @@ class APIRouter(routing.Router):
             methods=["OPTIONS"],
             name=name,
             include_in_schema=include_in_schema,
-            tags=tags,
+            tags=tags or [],
             summary=summary,
             description=description,
             operation_id=operation_id,
@@ -696,12 +420,12 @@ class APIRouter(routing.Router):
         path: str,
         name: str = None,
         include_in_schema: bool = True,
-        tags: typing.List[str] = [],
+        tags: List[str] = None,
         summary: str = None,
         description: str = None,
         operation_id: str = None,
         deprecated: bool = None,
-        response_type: typing.Type = None,
+        response_type: Type = None,
         response_description: str = "Successful Response",
         response_code=200,
         response_wrapper=JSONResponse,
@@ -711,7 +435,7 @@ class APIRouter(routing.Router):
             methods=["HEAD"],
             name=name,
             include_in_schema=include_in_schema,
-            tags=tags,
+            tags=tags or [],
             summary=summary,
             description=description,
             operation_id=operation_id,
@@ -727,12 +451,12 @@ class APIRouter(routing.Router):
         path: str,
         name: str = None,
         include_in_schema: bool = True,
-        tags: typing.List[str] = [],
+        tags: List[str] = None,
         summary: str = None,
         description: str = None,
         operation_id: str = None,
         deprecated: bool = None,
-        response_type: typing.Type = None,
+        response_type: Type = None,
         response_description: str = "Successful Response",
         response_code=200,
         response_wrapper=JSONResponse,
@@ -742,7 +466,7 @@ class APIRouter(routing.Router):
             methods=["PATCH"],
             name=name,
             include_in_schema=include_in_schema,
-            tags=tags,
+            tags=tags or [],
             summary=summary,
             description=description,
             operation_id=operation_id,
@@ -758,12 +482,12 @@ class APIRouter(routing.Router):
         path: str,
         name: str = None,
         include_in_schema: bool = True,
-        tags: typing.List[str] = [],
+        tags: List[str] = None,
         summary: str = None,
         description: str = None,
         operation_id: str = None,
         deprecated: bool = None,
-        response_type: typing.Type = None,
+        response_type: Type = None,
         response_description: str = "Successful Response",
         response_code=200,
         response_wrapper=JSONResponse,
@@ -773,7 +497,7 @@ class APIRouter(routing.Router):
             methods=["TRACE"],
             name=name,
             include_in_schema=include_in_schema,
-            tags=tags,
+            tags=tags or [],
             summary=summary,
             description=description,
             operation_id=operation_id,
index 4b6766eb7c86335308edee14cac3858738738d73..c0354fea7ce03da04cb9b7e2335febafc5ad983b 100644 (file)
@@ -1,11 +1,10 @@
-from starlette.requests import Request
+from enum import Enum
 
 from pydantic import Schema
-from enum import Enum
-from .base import SecurityBase, Types
 
-__all__ = ["APIKeyIn", "APIKeyBase", "APIKeyQuery", "APIKeyHeader", "APIKeyCookie"]
+from starlette.requests import Request
 
+from .base import SecurityBase, Types
 
 class APIKeyIn(Enum):
     query = "query"
@@ -21,7 +20,7 @@ class APIKeyBase(SecurityBase):
 
 class APIKeyQuery(APIKeyBase):
     in_ = Schema(APIKeyIn.query, alias="in")
-    
+
     async def __call__(self, requests: Request):
         return requests.query_params.get(self.name)
 
index 37433ff250854fae13dad944378fb6c8ae35f554..9ba430df9154cc3d1504f4e5ff2c31ec7434ea8f 100644 (file)
@@ -1,7 +1,6 @@
 from enum import Enum
-from pydantic import BaseModel, Schema
 
-__all__ = ["Types", "SecurityBase"]
+from pydantic import BaseModel, Schema
 
 
 class Types(Enum):
index aaaf86618c256475c8a029378fcf317fc56c6d2a..7a8bcfe48d2c7830496e2e8ce10cf8adea54889c 100644 (file)
@@ -1,9 +1,8 @@
+from pydantic import Schema
 
 from starlette.requests import Request
-from pydantic import Schema
-from .base import SecurityBase, Types
 
-__all__ = ["HTTPBase", "HTTPBasic", "HTTPBearer", "HTTPDigest"]
+from .base import SecurityBase, Types
 
 
 class HTTPBase(SecurityBase):
index a6607ef52bd770dc451f35b26a6ed1f5eab15e37..4febdafc29fcd01ce679dd67933838ea8bf4a9f5 100644 (file)
@@ -3,10 +3,8 @@ from typing import Dict
 from pydantic import BaseModel, Schema
 
 from starlette.requests import Request
-from .base import SecurityBase, Types
-
-# __all__ = ["HTTPBase", "HTTPBasic", "HTTPBearer", "HTTPDigest"]
 
+from .base import SecurityBase, Types
 
 class OAuthFlow(BaseModel):
     refreshUrl: str = None
index 2e7791a7acecc70afdd0305c6634d97adf3cd5dc..c84c56de87677ad82260d22b4fa956d906069231 100644 (file)
@@ -2,6 +2,7 @@ from starlette.requests import Request
 
 from .base import SecurityBase, Types
 
+
 class OpenIdConnect(SecurityBase):
     type_ = Types.openIdConnect
     openIdConnectUrl: str
diff --git a/fastapi/utils.py b/fastapi/utils.py
new file mode 100644 (file)
index 0000000..091f868
--- /dev/null
@@ -0,0 +1,46 @@
+import re
+from typing import Dict, Sequence, Set, Type
+
+from starlette.routing import BaseRoute
+
+from fastapi import routing
+from fastapi.openapi.constants import REF_PREFIX
+from pydantic import BaseModel
+from pydantic.fields import Field
+from pydantic.schema import get_flat_models_from_fields, model_process_schema
+
+
+def get_flat_models_from_routes(routes: Sequence[BaseRoute]):
+    body_fields_from_routes = []
+    responses_from_routes = []
+    for route in routes:
+        if route.include_in_schema and isinstance(route, routing.APIRoute):
+            if route.body_field:
+                assert isinstance(
+                    route.body_field, Field
+                ), "A request body must be a Pydantic Field"
+                body_fields_from_routes.append(route.body_field)
+            if route.response_field:
+                responses_from_routes.append(route.response_field)
+    flat_models = get_flat_models_from_fields(
+        body_fields_from_routes + responses_from_routes
+    )
+    return flat_models
+
+
+def get_model_definitions(
+    *, flat_models: Set[Type[BaseModel]], model_name_map: Dict[Type[BaseModel], str]
+):
+    definitions: Dict[str, Dict] = {}
+    for model in flat_models:
+        m_schema, m_definitions = model_process_schema(
+            model, model_name_map=model_name_map, ref_prefix=REF_PREFIX
+        )
+        definitions.update(m_definitions)
+        model_name = model_name_map[model]
+        definitions[model_name] = m_schema
+    return definitions
+
+
+def get_path_param_names(path: str):
+    return {item.strip("{}") for item in re.findall("{[^}]*}", path)}