]> git.ipfire.org Git - thirdparty/fastapi/fastapi.git/commitdiff
:sparkles: Refactor, update code, several features
authorSebastián Ramírez <tiangolo@gmail.com>
Thu, 6 Dec 2018 19:47:58 +0000 (23:47 +0400)
committerSebastián Ramírez <tiangolo@gmail.com>
Thu, 6 Dec 2018 19:47:58 +0000 (23:47 +0400)
fastapi/applications.py
fastapi/openapi/docs.py [new file with mode: 0644]
fastapi/openapi/utils.py
fastapi/routing.py
fastapi/security/api_key.py
fastapi/security/base.py
fastapi/security/http.py
fastapi/security/oauth2.py
fastapi/security/open_id_connect_url.py

index 3f5a45b73d6d20ee1b77c4fba1a6eda0bec24ddf..bb21076dfd91b03bf35f65ed78afbe3e302d7f18 100644 (file)
@@ -8,7 +8,8 @@ from starlette.responses import JSONResponse
 
 
 from fastapi import routing
-from fastapi.openapi.utils import get_swagger_ui_html, get_openapi, get_redoc_html
+from fastapi.openapi.utils import get_openapi
+from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
 
 
 async def http_exception(request, exc: HTTPException):
@@ -154,8 +155,10 @@ class FastAPI(Starlette):
                 response_wrapper=response_wrapper,
             )
             return func
-
         return decorator
+    
+    def include_router(self, router: "APIRouter", *, prefix=""):
+        self.router.include_router(router, prefix=prefix)
 
     def get(
         self,
diff --git a/fastapi/openapi/docs.py b/fastapi/openapi/docs.py
new file mode 100644 (file)
index 0000000..c8a1d61
--- /dev/null
@@ -0,0 +1,78 @@
+from starlette.responses import HTMLResponse
+
+def get_swagger_ui_html(*, openapi_url: str, title: str):
+    return HTMLResponse(
+        """
+    <! doctype html>
+    <html>
+    <head>
+    <link type="text/css" rel="stylesheet" href="//unpkg.com/swagger-ui-dist@3/swagger-ui.css">
+    <title>
+    """
+        + title
+        + """
+    </title>
+    </head>
+    <body>
+    <div id="swagger-ui">
+    </div>
+    <script src="//unpkg.com/swagger-ui-dist@3/swagger-ui-bundle.js"></script>
+    <!-- `SwaggerUIBundle` is now available on the page -->
+    <script>
+            
+    const ui = SwaggerUIBundle({
+        url: '"""
+        + openapi_url
+        + """',
+        dom_id: '#swagger-ui',
+        presets: [
+        SwaggerUIBundle.presets.apis,
+        SwaggerUIBundle.SwaggerUIStandalonePreset
+        ],
+        layout: "BaseLayout"
+    })
+    </script>
+    </body>
+    </html>
+    """,
+        media_type="text/html",
+    )
+
+
+def get_redoc_html(*, openapi_url: str, title: str):
+    return HTMLResponse(
+        """
+    <!DOCTYPE html>
+<html>
+  <head>
+    <title>
+    """
+        + title
+        + """
+    </title>
+    <!-- needed for adaptive design -->
+    <meta charset="utf-8"/>
+    <meta name="viewport" content="width=device-width, initial-scale=1">
+    <link href="https://fonts.googleapis.com/css?family=Montserrat:300,400,700|Roboto:300,400,700" rel="stylesheet">
+
+    <!--
+    ReDoc doesn't change outer page styles
+    -->
+    <style>
+      body {
+        margin: 0;
+        padding: 0;
+      }
+    </style>
+  </head>
+  <body>
+    <redoc spec-url='"""
+        + openapi_url
+        + """'></redoc>
+    <script src="https://cdn.jsdelivr.net/npm/redoc@next/bundles/redoc.standalone.js"> </script>
+  </body>
+</html>
+    """,
+        media_type="text/html",
+    )
index 3cf800740c119a9b967e7f14876749e766fea1b4..7dbeece737e3e9012bbba01021fa1e6969d03962 100644 (file)
@@ -1,4 +1,8 @@
-from typing import Any, Dict, Sequence, Type
+from typing import Any, Dict, Sequence, Type, List
+
+from pydantic.fields import Field
+from pydantic.schema import field_schema, get_model_name_map
+from pydantic.utils import lenient_issubclass
 
 from starlette.responses import HTMLResponse, JSONResponse
 from starlette.routing import BaseRoute
@@ -12,9 +16,7 @@ from fastapi.openapi.constants import REF_PREFIX, METHODS_WITH_BODY
 from fastapi.openapi.models import OpenAPI
 from fastapi.params import Body
 from fastapi.utils import get_flat_models_from_routes, get_model_definitions
-from pydantic.fields import Field
-from pydantic.schema import field_schema, get_model_name_map
-from pydantic.utils import lenient_issubclass
+
 
 validation_error_definition = {
     "title": "ValidationError",
@@ -49,91 +51,126 @@ def get_openapi_params(dependant: Dependant):
         + flat_dependant.cookie_params
     )
 
+
+def get_openapi_security_definitions(flat_dependant: Dependant):
+    security_definitions = {}
+    operation_security = []
+    for security_requirement in flat_dependant.security_requirements:
+        security_definition = jsonable_encoder(
+            security_requirement.security_scheme.model,
+            by_alias=True,
+            include_none=False,
+        )
+        security_name = (
+            security_requirement.security_scheme.scheme_name
+            
+        )
+        security_definitions[security_name] = security_definition
+        operation_security.append({security_name: security_requirement.scopes})
+    return security_definitions, operation_security
+
+
+def get_openapi_operation_parameters(all_route_params: List[Field]):
+    definitions: Dict[str, Dict] = {}
+    parameters = []
+    for param in all_route_params:
+        if "ValidationError" not in definitions:
+            definitions["ValidationError"] = validation_error_definition
+            definitions["HTTPValidationError"] = validation_error_response_definition
+        parameter = {
+            "name": param.alias,
+            "in": param.schema.in_.value,
+            "required": param.required,
+            "schema": field_schema(param, model_name_map={})[0],
+        }
+        if param.schema.description:
+            parameter["description"] = param.schema.description
+        if param.schema.deprecated:
+            parameter["deprecated"] = param.schema.deprecated
+        parameters.append(parameter)
+    return definitions, parameters
+
+
+def get_openapi_operation_request_body(
+    *, body_field: Field, model_name_map: Dict[Type, str]
+):
+    if not body_field:
+        return None
+    assert isinstance(body_field, Field)
+    body_schema, _ = field_schema(
+        body_field, model_name_map=model_name_map, ref_prefix=REF_PREFIX
+    )
+    if isinstance(body_field.schema, Body):
+        request_media_type = body_field.schema.media_type
+    else:
+        # Includes not declared media types (Schema)
+        request_media_type = "application/json"
+    required = body_field.required
+    request_body_oai = {}
+    if required:
+        request_body_oai["required"] = required
+    request_body_oai["content"] = {request_media_type: {"schema": body_schema}}
+    return request_body_oai
+
+
+def generate_operation_id(*, route: routing.APIRoute, method: str):
+    if route.operation_id:
+        return route.operation_id
+    path: str = route.path
+    operation_id = route.name + path
+    operation_id = operation_id.replace("{", "_").replace("}", "_").replace("/", "_")
+    operation_id = operation_id + "_" + method.lower()
+    return operation_id
+
+
+def generate_operation_summary(*, route: routing.APIRoute, method: str):
+    if route.summary:
+        return route.summary
+    return method.title() + " " + route.name.replace("_", " ").title()
+
+def get_openapi_operation_metadata(*, route: BaseRoute, method: str):
+    operation: Dict[str, Any] = {}
+    if route.tags:
+        operation["tags"] = route.tags
+    operation["summary"] = generate_operation_summary(route=route, method=method)
+    if route.description:
+        operation["description"] = route.description
+    operation["operationId"] = generate_operation_id(route=route, method=method)
+    if route.deprecated:
+        operation["deprecated"] = route.deprecated
+    return operation
+
+
 def get_openapi_path(*, route: BaseRoute, model_name_map: Dict[Type, str]):
     if not (route.include_in_schema and isinstance(route, routing.APIRoute)):
         return None
     path = {}
-    security_schemes = {}
-    definitions = {}
+    security_schemes: Dict[str, Any] = {}
+    definitions: Dict[str, Any] = {}
     for method in route.methods:
-        operation: Dict[str, Any] = {}
-        if route.tags:
-            operation["tags"] = route.tags
-        if route.summary:
-            operation["summary"] = route.summary
-        if route.description:
-            operation["description"] = route.description
-        if route.operation_id:
-            operation["operationId"] = route.operation_id
-        else:
-            operation["operationId"] = route.name
-        if route.deprecated:
-            operation["deprecated"] = route.deprecated
-        parameters = []
+        operation = get_openapi_operation_metadata(route=route, method=method)
+        parameters: List[Dict] = []
         flat_dependant = get_flat_dependant(route.dependant)
-        security_definitions = {}
-        for security_requirement in flat_dependant.security_requirements:
-            security_definition = jsonable_encoder(
-                security_requirement.security_scheme,
-                exclude={"scheme_name"},
-                by_alias=True,
-                include_none=False,
-            )
-            security_name = (
-                getattr(
-                    security_requirement.security_scheme, "scheme_name", None
-                )
-                or security_requirement.security_scheme.__class__.__name__
-            )
-            security_definitions[security_name] = security_definition
-            operation.setdefault("security", []).append(
-                {security_name: security_requirement.scopes}
-            )
+        security_definitions, operation_security = get_openapi_security_definitions(
+            flat_dependant=flat_dependant
+        )
+        if operation_security:
+            operation.setdefault("security", []).extend(operation_security)
         if security_definitions:
-            security_schemes.update(
-                security_definitions
-            )
+            security_schemes.update(security_definitions)
         all_route_params = get_openapi_params(route.dependant)
-        for param in all_route_params:
-            if "ValidationError" not in definitions:
-                definitions["ValidationError"] = validation_error_definition
-                definitions[
-                    "HTTPValidationError"
-                ] = validation_error_response_definition
-            parameter = {
-                "name": param.alias,
-                "in": param.schema.in_.value,
-                "required": param.required,
-                "schema": field_schema(param, model_name_map={})[0],
-            }
-            if param.schema.description:
-                parameter["description"] = param.schema.description
-            if param.schema.deprecated:
-                parameter["deprecated"] = param.schema.deprecated
-            parameters.append(parameter)
+        validation_definitions, operation_parameters = get_openapi_operation_parameters(
+            all_route_params=all_route_params
+        )
+        definitions.update(validation_definitions)
+        parameters.extend(operation_parameters)
         if parameters:
             operation["parameters"] = parameters
         if method in METHODS_WITH_BODY:
-            body_field = route.body_field
-            if body_field:
-                assert isinstance(body_field, Field)
-                body_schema, _ = field_schema(
-                    body_field,
-                    model_name_map=model_name_map,
-                    ref_prefix=REF_PREFIX,
-                )
-                if isinstance(body_field.schema, Body):
-                    request_media_type = body_field.schema.media_type
-                else:
-                    # Includes not declared media types (Schema)
-                    request_media_type = "application/json"
-                required = body_field.required
-                request_body_oai = {}
-                if required:
-                    request_body_oai["required"] = required
-                request_body_oai["content"] = {
-                    request_media_type: {"schema": body_schema}
-                }
+            request_body_oai = get_openapi_operation_request_body(
+                body_field=route.body_field, model_name_map=model_name_map
+            )
+            if request_body_oai:
                 operation["requestBody"] = request_body_oai
         response_code = str(route.response_code)
         response_schema = {"type": "string"}
@@ -206,75 +243,3 @@ def get_openapi(
         output["components"] = components
     output["paths"] = paths
     return jsonable_encoder(OpenAPI(**output), by_alias=True, include_none=False)
-
-
-def get_swagger_ui_html(*, openapi_url: str, title: str):
-    return HTMLResponse(
-        """
-    <! doctype html>
-    <html>
-    <head>
-    <link type="text/css" rel="stylesheet" href="//unpkg.com/swagger-ui-dist@3/swagger-ui.css">
-    <title>
-    """ + title + """
-    </title>
-    </head>
-    <body>
-    <div id="swagger-ui">
-    </div>
-    <script src="//unpkg.com/swagger-ui-dist@3/swagger-ui-bundle.js"></script>
-    <!-- `SwaggerUIBundle` is now available on the page -->
-    <script>
-            
-    const ui = SwaggerUIBundle({
-        url: '"""
-        + openapi_url
-        + """',
-        dom_id: '#swagger-ui',
-        presets: [
-        SwaggerUIBundle.presets.apis,
-        SwaggerUIBundle.SwaggerUIStandalonePreset
-        ],
-        layout: "BaseLayout"
-    })
-    </script>
-    </body>
-    </html>
-    """,
-        media_type="text/html",
-    )
-
-
-def get_redoc_html(*, openapi_url: str, title: str):
-    return HTMLResponse(
-        """
-    <!DOCTYPE html>
-<html>
-  <head>
-    <title>
-    """ + title + """
-    </title>
-    <!-- needed for adaptive design -->
-    <meta charset="utf-8"/>
-    <meta name="viewport" content="width=device-width, initial-scale=1">
-    <link href="https://fonts.googleapis.com/css?family=Montserrat:300,400,700|Roboto:300,400,700" rel="stylesheet">
-
-    <!--
-    ReDoc doesn't change outer page styles
-    -->
-    <style>
-      body {
-        margin: 0;
-        padding: 0;
-      }
-    </style>
-  </head>
-  <body>
-    <redoc spec-url='""" + openapi_url + """'></redoc>
-    <script src="https://cdn.jsdelivr.net/npm/redoc@next/bundles/redoc.standalone.js"> </script>
-  </body>
-</html>
-    """,
-        media_type="text/html",
-    )
index 6f7d592e5a6ad9c436a98b7ca7142537f2582a4d..22a62a53a24eed6688b5c01a3754604bd91ed14b 100644 (file)
@@ -2,6 +2,11 @@ import asyncio
 import inspect
 from typing import Callable, List, Type
 
+from pydantic import BaseConfig, BaseModel, Schema
+from pydantic.error_wrappers import ErrorWrapper, ValidationError
+from pydantic.fields import Field
+from pydantic.utils import lenient_issubclass
+
 from starlette import routing
 from starlette.concurrency import run_in_threadpool
 from starlette.exceptions import HTTPException
@@ -15,10 +20,6 @@ from fastapi import params
 from fastapi.dependencies.models import Dependant
 from fastapi.dependencies.utils import get_body_field, get_dependant, solve_dependencies
 from fastapi.encoders import jsonable_encoder
-from pydantic import BaseConfig, BaseModel, Schema
-from pydantic.error_wrappers import ErrorWrapper, ValidationError
-from pydantic.fields import Field
-from pydantic.utils import lenient_issubclass
 
 
 def serialize_response(*, field: Field = None, response):
@@ -44,11 +45,12 @@ def get_app(
     response_field: Type[Field] = None,
 ):
     is_coroutine = dependant.call and asyncio.iscoroutinefunction(dependant.call)
+    is_body_form = body_field and isinstance(body_field.schema, params.Form)
 
     async def app(request: Request) -> Response:
         body = None
         if body_field:
-            if isinstance(body_field.schema, params.Form):
+            if is_body_form:
                 raw_body = await request.form()
                 body = {}
                 for field, value in raw_body.items():
@@ -127,12 +129,7 @@ class APIRoute(routing.Route):
         response_code=200,
         response_wrapper=JSONResponse,
     ) -> None:
-        # TODO define how to read and provide security params, and how to have them globally too
-        # TODO implement dependencies and injection
-        # TODO refactor code structure
-        # TODO create testing
-        # TODO testing coverage
-        assert path.startswith("/"), "Routed paths must always start '/'"
+        assert path.startswith("/"), "Routed paths must always start with '/'"
         self.path = path
         self.endpoint = endpoint
         self.name = get_name(endpoint) if name is None else name
@@ -260,6 +257,39 @@ class APIRouter(routing.Router):
 
         return decorator
 
+    def include_router(self, router: "APIRouter", *, prefix=""):
+        if prefix:
+            assert prefix.startswith("/"), "A path prefix must start with '/'"
+            assert not prefix.endswith(
+                "/"
+            ), "A path prefix must not end with '/', as the routes will start with '/'"
+        for route in router.routes:
+            if isinstance(route, APIRoute):
+                self.add_api_route(
+                    prefix + route.path,
+                    route.endpoint,
+                    methods=route.methods,
+                    name=route.name,
+                    include_in_schema=route.include_in_schema,
+                    tags=route.tags,
+                    summary=route.summary,
+                    description=route.description,
+                    operation_id=route.operation_id,
+                    deprecated=route.deprecated,
+                    response_type=route.response_type,
+                    response_description=route.response_description,
+                    response_code=route.response_code,
+                    response_wrapper=route.response_wrapper,
+                )
+            elif isinstance(route, routing.Route):
+                self.add_route(
+                    prefix + route.path,
+                    route.endpoint,
+                    methods=route.methods,
+                    name=route.name,
+                    include_in_schema=route.include_in_schema,
+                )
+
     def get(
         self,
         path: str,
index c0354fea7ce03da04cb9b7e2335febafc5ad983b..047898dfe21646ecdf0ce3187803f7ca9f22450e 100644 (file)
@@ -1,39 +1,34 @@
-from enum import Enum
-
-from pydantic import Schema
-
 from starlette.requests import Request
 
-from .base import SecurityBase, Types
-
-class APIKeyIn(Enum):
-    query = "query"
-    header = "header"
-    cookie = "cookie"
-
+from .base import SecurityBase
+from fastapi.openapi.models import APIKeyIn, APIKey
 
 class APIKeyBase(SecurityBase):
-    type_ = Schema(Types.apiKey, alias="type")
-    in_: str = Schema(..., alias="in")
-    name: str
-
+    pass
 
 class APIKeyQuery(APIKeyBase):
-    in_ = Schema(APIKeyIn.query, alias="in")
+
+    def __init__(self, *, name: str, scheme_name: str = None):
+        self.model = APIKey(in_=APIKeyIn.query, name=name)
+        self.scheme_name = scheme_name or self.__class__.__name__
 
     async def __call__(self, requests: Request):
-        return requests.query_params.get(self.name)
+        return requests.query_params.get(self.model.name)
 
 
 class APIKeyHeader(APIKeyBase):
-    in_ = Schema(APIKeyIn.header, alias="in")
+    def __init__(self, *, name: str, scheme_name: str = None):
+        self.model = APIKey(in_=APIKeyIn.header, name=name)
+        self.scheme_name = scheme_name or self.__class__.__name__
 
     async def __call__(self, requests: Request):
-        return requests.headers.get(self.name)
+        return requests.headers.get(self.model.name)
 
 
 class APIKeyCookie(APIKeyBase):
-    in_ = Schema(APIKeyIn.cookie, alias="in")
+    def __init__(self, *, name: str, scheme_name: str = None):
+        self.model = APIKey(in_=APIKeyIn.cookie, name=name)
+        self.scheme_name = scheme_name or self.__class__.__name__
 
     async def __call__(self, requests: Request):
-        return requests.cookies.get(self.name)
+        return requests.cookies.get(self.model.name)
index 9ba430df9154cc3d1504f4e5ff2c31ec7434ea8f..8589da0be0a35d2a9444b9b3e44182c415183994 100644 (file)
@@ -1,16 +1,6 @@
-from enum import Enum
+from pydantic import BaseModel
 
-from pydantic import BaseModel, Schema
+from fastapi.openapi.models import SecurityBase as SecurityBaseModel
 
-
-class Types(Enum):
-    apiKey = "apiKey"
-    http = "http"
-    oauth2 = "oauth2"
-    openIdConnect = "openIdConnect"
-
-
-class SecurityBase(BaseModel):
-    scheme_name: str = None
-    type_: Types = Schema(..., alias="type")
-    description: str = None
+class SecurityBase:
+    pass
index 7a8bcfe48d2c7830496e2e8ce10cf8adea54889c..cee42b8687b8b2ff48cfe9d493202baac665a14f 100644 (file)
@@ -1,26 +1,40 @@
-from pydantic import Schema
-
 from starlette.requests import Request
 
-from .base import SecurityBase, Types
+from .base import SecurityBase
+from fastapi.openapi.models import HTTPBase as HTTPBaseModel, HTTPBearer as HTTPBearerModel
 
 
 class HTTPBase(SecurityBase):
-    type_ = Schema(Types.http, alias="type")
-    scheme: str
+    def __init__(self, *, scheme: str, scheme_name: str = None):
+        self.model = HTTPBaseModel(scheme=scheme)
+        self.scheme_name = scheme_name or self.__class__.__name__
 
     async def __call__(self, request: Request):
         return request.headers.get("Authorization")
 
 
 class HTTPBasic(HTTPBase):
-    scheme = "basic"
+    def __init__(self, *, scheme_name: str = None):
+        self.model = HTTPBaseModel(scheme="basic")
+        self.scheme_name = scheme_name or self.__class__.__name__
+    
+    async def __call__(self, request: Request):
+        return request.headers.get("Authorization")
 
 
 class HTTPBearer(HTTPBase):
-    scheme = "bearer"
-    bearerFormat: str = None
+    def __init__(self, *, bearerFormat: str = None, scheme_name: str = None):
+        self.model = HTTPBearerModel(bearerFormat=bearerFormat)
+        self.scheme_name = scheme_name or self.__class__.__name__
+    
+    async def __call__(self, request: Request):
+        return request.headers.get("Authorization")
 
 
 class HTTPDigest(HTTPBase):
-    scheme = "digest"
+    def __init__(self, *, scheme_name: str = None):
+        self.model = HTTPBaseModel(scheme="digest")
+        self.scheme_name = scheme_name or self.__class__.__name__
+    
+    async def __call__(self, request: Request):
+        return request.headers.get("Authorization")
index 4febdafc29fcd01ce679dd67933838ea8bf4a9f5..65517e962de7d9f064d830a952a592973743be76 100644 (file)
@@ -1,43 +1,13 @@
-from typing import Dict
-
-from pydantic import BaseModel, Schema
-
 from starlette.requests import Request
 
-from .base import SecurityBase, Types
-
-class OAuthFlow(BaseModel):
-    refreshUrl: str = None
-    scopes: Dict[str, str] = {}
-
-
-class OAuthFlowImplicit(OAuthFlow):
-    authorizationUrl: str
-
-
-class OAuthFlowPassword(OAuthFlow):
-    tokenUrl: str
-
-
-class OAuthFlowClientCredentials(OAuthFlow):
-    tokenUrl: str
-
-
-class OAuthFlowAuthorizationCode(OAuthFlow):
-    authorizationUrl: str
-    tokenUrl: str
-
-
-class OAuthFlows(BaseModel):
-    implicit: OAuthFlowImplicit = None
-    password: OAuthFlowPassword = None
-    clientCredentials: OAuthFlowClientCredentials = None
-    authorizationCode: OAuthFlowAuthorizationCode = None
+from .base import SecurityBase
+from fastapi.openapi.models import OAuth2 as OAuth2Model, OAuthFlows as OAuthFlowsModel
 
 
 class OAuth2(SecurityBase):
-    type_ = Schema(Types.oauth2, alias="type")
-    flows: OAuthFlows
-
+    def __init__(self, *, flows: OAuthFlowsModel = OAuthFlowsModel(), scheme_name: str = None):
+        self.model = OAuth2Model(flows=flows)
+        self.scheme_name = scheme_name or self.__class__.__name__
+    
     async def __call__(self, request: Request):
         return request.headers.get("Authorization")
index c84c56de87677ad82260d22b4fa956d906069231..49c5aae2d86d800c5c3c188271d91936d612becf 100644 (file)
@@ -1,11 +1,13 @@
 from starlette.requests import Request
 
-from .base import SecurityBase, Types
+from .base import SecurityBase
+from fastapi.openapi.models import OpenIdConnect as OpenIdConnectModel
 
 
 class OpenIdConnect(SecurityBase):
-    type_ = Types.openIdConnect
-    openIdConnectUrl: str
-
+    def __init__(self, *, openIdConnectUrl: str, scheme_name: str = None):
+        self.model = OpenIdConnectModel(openIdConnectUrl=openIdConnectUrl)
+        self.scheme_name = scheme_name or self.__class__.__name__
+    
     async def __call__(self, request: Request):
         return request.headers.get("Authorization")