]> git.ipfire.org Git - thirdparty/fastapi/fastapi.git/commitdiff
:tada: Start tracking messy initial stage
authorSebastián Ramírez <tiangolo@gmail.com>
Wed, 5 Dec 2018 06:56:50 +0000 (10:56 +0400)
committerSebastián Ramírez <tiangolo@gmail.com>
Wed, 5 Dec 2018 06:56:50 +0000 (10:56 +0400)
...before refactoring and breaking something

fastapi/__init__.py [new file with mode: 0644]
fastapi/applications.py [new file with mode: 0644]
fastapi/params.py [new file with mode: 0644]
fastapi/pydantic_utils.py [new file with mode: 0644]
fastapi/routing.py [new file with mode: 0644]
fastapi/security/__init__.py [new file with mode: 0644]
fastapi/security/api_key.py [new file with mode: 0644]
fastapi/security/base.py [new file with mode: 0644]
fastapi/security/http.py [new file with mode: 0644]
fastapi/security/oauth2.py [new file with mode: 0644]
fastapi/security/open_id_connect_url.py [new file with mode: 0644]

diff --git a/fastapi/__init__.py b/fastapi/__init__.py
new file mode 100644 (file)
index 0000000..a52bbcc
--- /dev/null
@@ -0,0 +1,3 @@
+"""Fast API framework, fast high performance, fast to learn, fast to code"""
+
+__version__ = '0.1'
diff --git a/fastapi/applications.py b/fastapi/applications.py
new file mode 100644 (file)
index 0000000..2e1875a
--- /dev/null
@@ -0,0 +1,581 @@
+import typing
+import inspect
+
+from starlette.applications import Starlette
+from starlette.middleware.lifespan import LifespanMiddleware
+from starlette.middleware.errors import ServerErrorMiddleware
+from starlette.exceptions import ExceptionMiddleware
+from starlette.responses import JSONResponse, HTMLResponse, PlainTextResponse
+from starlette.requests import Request
+from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY
+
+from pydantic import BaseModel, BaseConfig, Schema
+from pydantic.utils import lenient_issubclass
+from pydantic.fields import Field
+from pydantic.schema import (
+    field_schema,
+    get_flat_models_from_models,
+    get_flat_models_from_fields,
+    get_model_name_map,
+    schema,
+    model_process_schema,
+)
+
+from .routing import APIRouter, APIRoute, get_openapi_params, get_flat_dependant
+from .pydantic_utils import jsonable_encoder
+
+
+def docs(openapi_url):
+    return HTMLResponse(
+        """
+    <! doctype html>
+    <html>
+    <head>
+    <link type="text/css" rel="stylesheet" href="//unpkg.com/swagger-ui-dist@3/swagger-ui.css">
+    </head>
+    <body>
+    <div id="swagger-ui">
+    </div>
+    <script src="//unpkg.com/swagger-ui-dist@3/swagger-ui-bundle.js"></script>
+    <!-- `SwaggerUIBundle` is now available on the page -->
+    <script>
+            
+    const ui = SwaggerUIBundle({
+        url: '""" + openapi_url + """',
+        dom_id: '#swagger-ui',
+        presets: [
+        SwaggerUIBundle.presets.apis,
+        SwaggerUIBundle.SwaggerUIStandalonePreset
+        ],
+        layout: "BaseLayout"
+    })
+    </script>
+    </body>
+    </html>
+    """,
+        media_type="text/html",
+    )
+
+
+class FastAPI(Starlette):
+    def __init__(
+        self,
+        debug: bool = False,
+        template_directory: str = None,
+        title: str = "Fast API",
+        description: str = "",
+        version: str = "0.1.0",
+        openapi_url: str = "/openapi.json",
+        docs_url: str = "/docs",
+        **extra: typing.Dict[str, typing.Any],
+    ) -> None:
+        self._debug = debug
+        self.router = APIRouter()
+        self.exception_middleware = ExceptionMiddleware(self.router, debug=debug)
+        self.error_middleware = ServerErrorMiddleware(
+            self.exception_middleware, debug=debug
+        )
+        self.lifespan_middleware = LifespanMiddleware(self.error_middleware)
+        self.schema_generator = None  # type: typing.Optional[BaseSchemaGenerator]
+        self.template_env = self.load_template_env(template_directory)
+
+        self.title = title
+        self.description = description
+        self.version = version
+        self.openapi_url = openapi_url
+        self.docs_url = docs_url
+        self.extra = extra
+
+        self.openapi_version = "3.0.2"
+
+        if self.openapi_url:
+            assert self.title, "A title must be provided for OpenAPI, e.g.: 'My API'"
+            assert self.version, "A version must be provided for OpenAPI, e.g.: '2.1.0'"
+
+        if self.docs_url:
+            assert self.openapi_url, "The openapi_url is required for the docs"
+
+        self.add_route(
+            self.openapi_url,
+            lambda req: JSONResponse(self.openapi()),
+            include_in_schema=False,
+        )
+        self.add_route(self.docs_url, lambda r: docs(self.openapi_url), include_in_schema=False)
+
+    def add_api_route(
+        self,
+        path: str,
+        endpoint: typing.Callable,
+        methods: typing.List[str] = None,
+        name: str = None,
+        include_in_schema: bool = True,
+        tags: typing.List[str] = [],
+        summary: str = None,
+        description: str = None,
+        operation_id: str = None,
+        deprecated: bool = None,
+        response_type: typing.Type = None,
+        response_description: str = "Successful Response",
+        response_code=200,
+        response_wrapper=JSONResponse,
+    ) -> None:
+        self.router.add_api_route(
+            path,
+            endpoint=endpoint,
+            methods=methods,
+            name=name,
+            include_in_schema=include_in_schema,
+            tags=tags,
+            summary=summary,
+            description=description,
+            operation_id=operation_id,
+            deprecated=deprecated,
+            response_type=response_type,
+            response_description=response_description,
+            response_code=response_code,
+            response_wrapper=response_wrapper,
+        )
+
+    def api_route(
+        self,
+        path: str,
+        methods: typing.List[str] = None,
+        name: str = None,
+        include_in_schema: bool = True,
+        tags: typing.List[str] = [],
+        summary: str = None,
+        description: str = None,
+        operation_id: str = None,
+        deprecated: bool = None,
+        response_type: typing.Type = None,
+        response_description: str = "Successful Response",
+        response_code=200,
+        response_wrapper=JSONResponse,
+    ) -> typing.Callable:
+        def decorator(func: typing.Callable) -> typing.Callable:
+            self.router.add_api_route(
+                path,
+                func,
+                methods=methods,
+                name=name,
+                include_in_schema=include_in_schema,
+                tags=tags,
+                summary=summary,
+                description=description,
+                operation_id=operation_id,
+                deprecated=deprecated,
+                response_type=response_type,
+                response_description=response_description,
+                response_code=response_code,
+                response_wrapper=response_wrapper,
+            )
+            return func
+
+        return decorator
+
+    def get(
+        self,
+        path: str,
+        name: str = None,
+        include_in_schema: bool = True,
+        tags: typing.List[str] = [],
+        summary: str = None,
+        description: str = None,
+        operation_id: str = None,
+        deprecated: bool = None,
+        response_type: typing.Type = None,
+        response_description: str = "Successful Response",
+        response_code=200,
+        response_wrapper=JSONResponse,
+    ):
+        return self.router.get(
+            path=path,
+            name=name,
+            include_in_schema=include_in_schema,
+            tags=tags,
+            summary=summary,
+            description=description,
+            operation_id=operation_id,
+            deprecated=deprecated,
+            response_type=response_type,
+            response_description=response_description,
+            response_code=response_code,
+            response_wrapper=response_wrapper,
+        )
+
+    def put(
+        self,
+        path: str,
+        name: str = None,
+        include_in_schema: bool = True,
+        tags: typing.List[str] = [],
+        summary: str = None,
+        description: str = None,
+        operation_id: str = None,
+        deprecated: bool = None,
+        response_type: typing.Type = None,
+        response_description: str = "Successful Response",
+        response_code=200,
+        response_wrapper=JSONResponse,
+    ):
+        return self.router.put(
+            path=path,
+            name=name,
+            include_in_schema=include_in_schema,
+            tags=tags,
+            summary=summary,
+            description=description,
+            operation_id=operation_id,
+            deprecated=deprecated,
+            response_type=response_type,
+            response_description=response_description,
+            response_code=response_code,
+            response_wrapper=response_wrapper,
+        )
+
+    def post(
+        self,
+        path: str,
+        name: str = None,
+        include_in_schema: bool = True,
+        tags: typing.List[str] = [],
+        summary: str = None,
+        description: str = None,
+        operation_id: str = None,
+        deprecated: bool = None,
+        response_type: typing.Type = None,
+        response_description: str = "Successful Response",
+        response_code=200,
+        response_wrapper=JSONResponse,
+    ):
+        return self.router.post(
+            path=path,
+            name=name,
+            include_in_schema=include_in_schema,
+            tags=tags,
+            summary=summary,
+            description=description,
+            operation_id=operation_id,
+            deprecated=deprecated,
+            response_type=response_type,
+            response_description=response_description,
+            response_code=response_code,
+            response_wrapper=response_wrapper,
+        )
+
+    def delete(
+        self,
+        path: str,
+        name: str = None,
+        include_in_schema: bool = True,
+        tags: typing.List[str] = [],
+        summary: str = None,
+        description: str = None,
+        operation_id: str = None,
+        deprecated: bool = None,
+        response_type: typing.Type = None,
+        response_description: str = "Successful Response",
+        response_code=200,
+        response_wrapper=JSONResponse,
+    ):
+        return self.router.delete(
+            path=path,
+            name=name,
+            include_in_schema=include_in_schema,
+            tags=tags,
+            summary=summary,
+            description=description,
+            operation_id=operation_id,
+            deprecated=deprecated,
+            response_type=response_type,
+            response_description=response_description,
+            response_code=response_code,
+            response_wrapper=response_wrapper,
+        )
+
+    def options(
+        self,
+        path: str,
+        name: str = None,
+        include_in_schema: bool = True,
+        tags: typing.List[str] = [],
+        summary: str = None,
+        description: str = None,
+        operation_id: str = None,
+        deprecated: bool = None,
+        response_type: typing.Type = None,
+        response_description: str = "Successful Response",
+        response_code=200,
+        response_wrapper=JSONResponse,
+    ):
+        return self.router.options(
+            path=path,
+            name=name,
+            include_in_schema=include_in_schema,
+            tags=tags,
+            summary=summary,
+            description=description,
+            operation_id=operation_id,
+            deprecated=deprecated,
+            response_type=response_type,
+            response_description=response_description,
+            response_code=response_code,
+            response_wrapper=response_wrapper,
+        )
+
+    def head(
+        self,
+        path: str,
+        name: str = None,
+        include_in_schema: bool = True,
+        tags: typing.List[str] = [],
+        summary: str = None,
+        description: str = None,
+        operation_id: str = None,
+        deprecated: bool = None,
+        response_type: typing.Type = None,
+        response_description: str = "Successful Response",
+        response_code=200,
+        response_wrapper=JSONResponse,
+    ):
+        return self.router.head(
+            path=path,
+            name=name,
+            include_in_schema=include_in_schema,
+            tags=tags,
+            summary=summary,
+            description=description,
+            operation_id=operation_id,
+            deprecated=deprecated,
+            response_type=response_type,
+            response_description=response_description,
+            response_code=response_code,
+            response_wrapper=response_wrapper,
+        )
+
+    def patch(
+        self,
+        path: str,
+        name: str = None,
+        include_in_schema: bool = True,
+        tags: typing.List[str] = [],
+        summary: str = None,
+        description: str = None,
+        operation_id: str = None,
+        deprecated: bool = None,
+        response_type: typing.Type = None,
+        response_description: str = "Successful Response",
+        response_code=200,
+        response_wrapper=JSONResponse,
+    ):
+        return self.router.patch(
+            path=path,
+            name=name,
+            include_in_schema=include_in_schema,
+            tags=tags,
+            summary=summary,
+            description=description,
+            operation_id=operation_id,
+            deprecated=deprecated,
+            response_type=response_type,
+            response_description=response_description,
+            response_code=response_code,
+            response_wrapper=response_wrapper,
+        )
+
+    def trace(
+        self,
+        path: str,
+        name: str = None,
+        include_in_schema: bool = True,
+        tags: typing.List[str] = [],
+        summary: str = None,
+        description: str = None,
+        operation_id: str = None,
+        deprecated: bool = None,
+        response_type: typing.Type = None,
+        response_description: str = "Successful Response",
+        response_code=200,
+        response_wrapper=JSONResponse,
+    ):
+        return self.router.trace(
+            path=path,
+            name=name,
+            include_in_schema=include_in_schema,
+            tags=tags,
+            summary=summary,
+            description=description,
+            operation_id=operation_id,
+            deprecated=deprecated,
+            response_type=response_type,
+            response_description=response_description,
+            response_code=response_code,
+            response_wrapper=response_wrapper,
+        )
+
+    def openapi(self):
+        info = {"title": self.title, "version": self.version}
+        if self.description:
+            info["description"] = self.description
+        output = {"openapi": self.openapi_version, "info": info}
+        components = {}
+        paths = {}
+        methods_with_body = set(("POST", "PUT"))
+        body_fields_from_routes = []
+        responses_from_routes = []
+        ref_prefix = "#/components/schemas/"
+        for route in self.routes:
+            route: APIRoute
+            if route.include_in_schema and isinstance(route, APIRoute):
+                if route.request_body:
+                    assert isinstance(
+                        route.request_body, Field
+                    ), "A request body must be a Pydantic BaseModel or Field"
+                    body_fields_from_routes.append(route.request_body)
+                if route.response_field:
+                    responses_from_routes.append(route.response_field)
+        flat_models = get_flat_models_from_fields(
+            body_fields_from_routes + responses_from_routes
+        )
+        model_name_map = get_model_name_map(flat_models)
+        definitions = {}
+        for model in flat_models:
+            m_schema, m_definitions = model_process_schema(
+                model, model_name_map=model_name_map, ref_prefix=ref_prefix
+            )
+            definitions.update(m_definitions)
+            model_name = model_name_map[model]
+            definitions[model_name] = m_schema
+        validation_error_definition = {
+            "title": "ValidationError",
+            "type": "object",
+            "properties": {
+                "loc": {
+                    "title": "Location",
+                    "type": "array",
+                    "items": {"type": "string"},
+                },
+                "msg": {"title": "Message", "type": "string"},
+                "type": {"title": "Error Type", "type": "string"},
+            },
+            "required": ["loc", "msg", "type"],
+        }
+        validation_error_response_definition = {
+            "title": "HTTPValidationError",
+            "type": "object",
+            "properties": {
+                "detail": {
+                    "title": "Detail",
+                    "type": "array",
+                    "items": {"$ref": ref_prefix + "ValidationError"},
+                }
+            },
+        }
+        for route in self.routes:
+            route: APIRoute
+            if route.include_in_schema and isinstance(route, APIRoute):
+                path = paths.get(route.path, {})
+                for method in route.methods:
+                    operation = {}
+                    if route.tags:
+                        operation["tags"] = route.tags
+                    if route.summary:
+                        operation["summary"] = route.summary
+                    if route.description:
+                        operation["description"] = route.description
+                    if route.operation_id:
+                        operation["operationId"] = route.operation_id
+                    else:
+                        operation["operationId"] = route.name
+                    if route.deprecated:
+                        operation["deprecated"] = route.deprecated
+                    parameters = []
+                    flat_dependant = get_flat_dependant(route.dependant)
+                    security_definitions = {}
+                    for security_scheme in flat_dependant.security_schemes:
+                        security_definition = jsonable_encoder(security_scheme, exclude=("scheme_name",), by_alias=True, include_none=False)
+                        security_name = getattr(security_scheme, "scheme_name", None) or security_scheme.__class__.__name__
+                        security_definitions[security_name] = security_definition
+                    if security_definitions:
+                        components.setdefault("securitySchemes", {}).update(security_definitions)
+                        operation["security"] = [{name: []} for name in security_definitions]
+                    all_route_params = get_openapi_params(route.dependant)
+                    for param in all_route_params:
+                        if "ValidationError" not in definitions:
+                            definitions["ValidationError"] = validation_error_definition
+                            definitions[
+                                "HTTPValidationError"
+                            ] = validation_error_response_definition
+                        parameter = {
+                            "name": param.alias,
+                            "in": param.schema.in_.value,
+                            "required": param.required,
+                            "schema": field_schema(param, model_name_map={})[0],
+                        }
+                        if param.schema.description:
+                            parameter["description"] = param.schema.description
+                        if param.schema.deprecated:
+                            parameter["deprecated"] = param.schema.deprecated
+                        parameters.append(parameter)
+                    if parameters:
+                        operation["parameters"] = parameters
+                    if method in methods_with_body:
+                        request_body = getattr(route, "request_body", None)
+                        if request_body:
+                            assert isinstance(request_body, Field)
+                            body_schema, _ = field_schema(
+                                request_body,
+                                model_name_map=model_name_map,
+                                ref_prefix=ref_prefix,
+                            )
+                            required = request_body.required
+                            request_body_oai = {}
+                            if required:
+                                request_body_oai["required"] = required
+                            request_body_oai["content"] = {
+                                "application/json": {"schema": body_schema}
+                            }
+                            operation["requestBody"] = request_body_oai
+                    response_code = str(route.response_code)
+                    response_schema = {"type": "string"}
+                    if lenient_issubclass(route.response_wrapper, JSONResponse):
+                        response_media_type = "application/json"
+                        if route.response_field:
+                            response_schema, _ = field_schema(
+                                route.response_field,
+                                model_name_map=model_name_map,
+                                ref_prefix=ref_prefix,
+                            )
+                        else:
+                            response_schema = {}
+                    elif lenient_issubclass(route.response_wrapper, HTMLResponse):
+                        response_media_type = "text/html"
+                    else:
+                        response_media_type = "text/plain"
+                    content = {response_media_type: {"schema": response_schema}}
+                    operation["responses"] = {
+                        response_code: {
+                            "description": route.response_description,
+                            "content": content,
+                        }
+                    }
+                    if all_route_params or getattr(route, "request_body", None):
+                        operation["responses"][str(HTTP_422_UNPROCESSABLE_ENTITY)] = {
+                            "description": "Validation Error",
+                            "content": {
+                                "application/json": {
+                                    "schema": {
+                                        "$ref": ref_prefix + "HTTPValidationError"
+                                    }
+                                }
+                            },
+                        }
+                    path[method.lower()] = operation
+                paths[route.path] = path
+        if definitions:
+            components.setdefault("schemas", {}).update(definitions)
+        if components:
+            output["components"] = components
+        output["paths"] = paths
+        return output
diff --git a/fastapi/params.py b/fastapi/params.py
new file mode 100644 (file)
index 0000000..98b8094
--- /dev/null
@@ -0,0 +1,246 @@
+from typing import Sequence
+from enum import Enum
+from pydantic import Schema
+
+
+class ParamTypes(Enum):
+    query = "query"
+    header = "header"
+    path = "path"
+    cookie = "cookie"
+
+
+class Param(Schema):
+    in_: ParamTypes
+    def __init__(
+        self,
+        default,
+        *,
+        deprecated: bool = None,
+        alias: str = None,
+        title: str = None,
+        description: str = None,
+        gt: float = None,
+        ge: float = None,
+        lt: float = None,
+        le: float = None,
+        min_length: int = None,
+        max_length: int = None,
+        regex: str = None,
+        **extra: object,
+    ):
+        self.deprecated = deprecated
+        super().__init__(
+            default,
+            alias=alias,
+            title=title,
+            description=description,
+            gt=gt,
+            ge=ge,
+            lt=lt,
+            le=le,
+            min_length=min_length,
+            max_length=max_length,
+            regex=regex,
+            **extra,
+        )
+
+
+class Path(Param):
+    in_ = ParamTypes.path
+
+    def __init__(
+        self,
+        default,
+        *,
+        deprecated: bool = None,
+        alias: str = None,
+        title: str = None,
+        description: str = None,
+        gt: float = None,
+        ge: float = None,
+        lt: float = None,
+        le: float = None,
+        min_length: int = None,
+        max_length: int = None,
+        regex: str = None,
+        **extra: object,
+    ):
+        self.description = description
+        self.deprecated = deprecated
+        self.in_ = self.in_
+        super().__init__(
+            default,
+            alias=alias,
+            title=title,
+            description=description,
+            gt=gt,
+            ge=ge,
+            lt=lt,
+            le=le,
+            min_length=min_length,
+            max_length=max_length,
+            regex=regex,
+            **extra,
+        )
+
+
+class Query(Param):
+    in_ = ParamTypes.query
+
+    def __init__(
+        self,
+        default,
+        *,
+        deprecated: bool = None,
+        alias: str = None,
+        title: str = None,
+        description: str = None,
+        gt: float = None,
+        ge: float = None,
+        lt: float = None,
+        le: float = None,
+        min_length: int = None,
+        max_length: int = None,
+        regex: str = None,
+        **extra: object,
+    ):
+        self.description = description
+        self.deprecated = deprecated
+        super().__init__(
+            default,
+            alias=alias,
+            title=title,
+            description=description,
+            gt=gt,
+            ge=ge,
+            lt=lt,
+            le=le,
+            min_length=min_length,
+            max_length=max_length,
+            regex=regex,
+            **extra,
+        )
+
+
+class Header(Param):
+    in_ = ParamTypes.header
+
+    def __init__(
+        self,
+        default,
+        *,
+        deprecated: bool = None,
+        alias: str = None,
+        title: str = None,
+        description: str = None,
+        gt: float = None,
+        ge: float = None,
+        lt: float = None,
+        le: float = None,
+        min_length: int = None,
+        max_length: int = None,
+        regex: str = None,
+        **extra: object,
+    ):
+        self.description = description
+        self.deprecated = deprecated
+        super().__init__(
+            default,
+            alias=alias,
+            title=title,
+            description=description,
+            gt=gt,
+            ge=ge,
+            lt=lt,
+            le=le,
+            min_length=min_length,
+            max_length=max_length,
+            regex=regex,
+            **extra,
+        )
+
+
+class Cookie(Param):
+    in_ = ParamTypes.cookie
+
+    def __init__(
+        self,
+        default,
+        *,
+        deprecated: bool = None,
+        alias: str = None,
+        title: str = None,
+        description: str = None,
+        gt: float = None,
+        ge: float = None,
+        lt: float = None,
+        le: float = None,
+        min_length: int = None,
+        max_length: int = None,
+        regex: str = None,
+        **extra: object,
+    ):
+        self.description = description
+        self.deprecated = deprecated
+        super().__init__(
+            default,
+            alias=alias,
+            title=title,
+            description=description,
+            gt=gt,
+            ge=ge,
+            lt=lt,
+            le=le,
+            min_length=min_length,
+            max_length=max_length,
+            regex=regex,
+            **extra,
+        )
+
+
+class Body(Schema):
+    def __init__(
+        self,
+        default,
+        *,
+        sub_key=False,
+        alias: str = None,
+        title: str = None,
+        description: str = None,
+        gt: float = None,
+        ge: float = None,
+        lt: float = None,
+        le: float = None,
+        min_length: int = None,
+        max_length: int = None,
+        regex: str = None,
+        **extra: object,
+    ):
+        self.sub_key = sub_key
+        super().__init__(
+            default,
+            alias=alias,
+            title=title,
+            description=description,
+            gt=gt,
+            ge=ge,
+            lt=lt,
+            le=le,
+            min_length=min_length,
+            max_length=max_length,
+            regex=regex,
+            **extra,
+        )
+
+
+class Depends:
+    def __init__(self, dependency = None):
+        self.dependency = dependency
+
+
+class Security:
+    def __init__(self, security_scheme = None, scopes: Sequence[str] = None):
+        self.security_scheme = security_scheme
+        self.scopes = scopes
+
diff --git a/fastapi/pydantic_utils.py b/fastapi/pydantic_utils.py
new file mode 100644 (file)
index 0000000..8fc6589
--- /dev/null
@@ -0,0 +1,33 @@
+from types import GeneratorType
+from typing import Set
+from pydantic import BaseModel
+from enum import Enum
+from pydantic.json import pydantic_encoder
+
+
+def jsonable_encoder(
+    obj, include: Set[str] = None, exclude: Set[str] = set(), by_alias: bool = False, include_none=True,
+):
+    if isinstance(obj, BaseModel):
+        return jsonable_encoder(
+            obj.dict(include=include, exclude=exclude, by_alias=by_alias), include_none=include_none
+        )
+    elif isinstance(obj, Enum):
+        return obj.value
+    if isinstance(obj, (str, int, float, type(None))):
+        return obj
+    if isinstance(obj, dict):
+        return {
+            jsonable_encoder(
+                key, by_alias=by_alias, include_none=include_none,
+            ): jsonable_encoder(
+                value, by_alias=by_alias, include_none=include_none,
+            )
+            for key, value in obj.items() if value is not None or include_none
+        }
+    if isinstance(obj, (list, set, frozenset, GeneratorType, tuple)):
+        return [
+            jsonable_encoder(item, include=include, exclude=exclude, by_alias=by_alias, include_none=include_none)
+            for item in obj
+        ]
+    return pydantic_encoder(obj)
diff --git a/fastapi/routing.py b/fastapi/routing.py
new file mode 100644 (file)
index 0000000..8c95b33
--- /dev/null
@@ -0,0 +1,785 @@
+import asyncio
+import inspect
+import re
+import typing
+from copy import deepcopy
+
+from starlette import routing
+from starlette.routing import get_name, request_response
+from starlette.requests import Request
+from starlette.responses import Response, JSONResponse
+from starlette.concurrency import run_in_threadpool
+from starlette.exceptions import HTTPException
+from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY
+
+
+from pydantic.fields import Field, Required
+from pydantic.schema import get_annotation_from_schema
+from pydantic import BaseConfig, BaseModel, create_model, Schema
+from pydantic.error_wrappers import ErrorWrapper, ValidationError
+from pydantic.errors import MissingError
+from pydantic.utils import lenient_issubclass
+from .pydantic_utils import jsonable_encoder
+
+from fastapi import params
+from fastapi.security.base import SecurityBase
+
+
+param_supported_types = (str, int, float, bool)
+
+
+class Dependant:
+    def __init__(
+        self,
+        *,
+        path_params: typing.List[Field] = None,
+        query_params: typing.List[Field] = None,
+        header_params: typing.List[Field] = None,
+        cookie_params: typing.List[Field] = None,
+        body_params: typing.List[Field] = None,
+        dependencies: typing.List["Dependant"] = None,
+        security_schemes: typing.List[Field] = None,
+        name: str = None,
+        call: typing.Callable = None,
+        request_param_name: str = None,
+    ) -> None:
+        self.path_params: typing.List[Field] = path_params or []
+        self.query_params: typing.List[Field] = query_params or []
+        self.header_params: typing.List[Field] = header_params or []
+        self.cookie_params: typing.List[Field] = cookie_params or []
+        self.body_params: typing.List[Field] = body_params or []
+        self.dependencies: typing.List[Dependant] = dependencies or []
+        self.security_schemes: typing.List[Field] = security_schemes or []
+        self.request_param_name = request_param_name
+        self.name = name
+        self.call: typing.Callable = call
+
+
+def request_params_to_args(
+    required_params: typing.List[Field], received_params: typing.Dict[str, typing.Any]
+) -> typing.Tuple[typing.Dict[str, typing.Any], typing.List[ErrorWrapper]]:
+    values = {}
+    errors = []
+    for field in required_params:
+        value = received_params.get(field.alias)
+        if value is None:
+            if field.required:
+                errors.append(
+                    ErrorWrapper(MissingError(), loc=field.alias, config=BaseConfig)
+                )
+            else:
+                values[field.name] = deepcopy(field.default)
+            continue
+        v_, errors_ = field.validate(
+            value, values, loc=(field.schema.in_.value, field.alias)
+        )
+        if isinstance(errors_, ErrorWrapper):
+            errors_: ErrorWrapper
+            errors.append(errors_)
+        elif isinstance(errors_, list):
+            errors.extend(errors_)
+        else:
+            values[field.name] = v_
+    return values, errors
+
+
+def request_body_to_args(
+    required_params: typing.List[Field], received_body: typing.Dict[str, typing.Any]
+) -> typing.Tuple[typing.Dict[str, typing.Any], typing.List[ErrorWrapper]]:
+    values = {}
+    errors = []
+    if required_params:
+        field = required_params[0]
+        sub_key = getattr(field.schema, "sub_key", None)
+        if len(required_params) == 1 and not sub_key:
+            received_body = {field.alias: received_body}
+        for field in required_params:
+            value = received_body.get(field.alias)
+            if value is None:
+                if field.required:
+                    errors.append(
+                        ErrorWrapper(
+                            MissingError(), loc=("body", field.alias), config=BaseConfig
+                        )
+                    )
+                else:
+                    values[field.name] = deepcopy(field.default)
+                continue
+
+            v_, errors_ = field.validate(value, values, loc=("body", field.alias))
+            if isinstance(errors_, ErrorWrapper):
+                errors_: ErrorWrapper
+                errors.append(errors_)
+            elif isinstance(errors_, list):
+                errors.extend(errors_)
+            else:
+                values[field.name] = v_
+    return values, errors
+
+
+def add_param_to_fields(
+    *,
+    param: inspect.Parameter,
+    dependant: Dependant,
+    default_schema=params.Param,
+    force_type: params.ParamTypes = None,
+):
+    default_value = Required
+    if not param.default == param.empty:
+        default_value = param.default
+    if isinstance(default_value, params.Param):
+        schema = default_value
+        default_value = schema.default
+        if schema.in_ is None:
+            schema.in_ = default_schema.in_
+        if force_type:
+            schema.in_ = force_type
+    else:
+        schema = default_schema(default_value)
+    required = default_value == Required
+    annotation = typing.Any
+    if not param.annotation == param.empty:
+        annotation = param.annotation
+    annotation = get_annotation_from_schema(annotation, schema)
+    Config = BaseConfig
+    field = Field(
+        name=param.name,
+        type_=annotation,
+        default=None if required else default_value,
+        alias=schema.alias or param.name,
+        required=required,
+        model_config=Config,
+        class_validators=[],
+        schema=schema,
+    )
+    if schema.in_ == params.ParamTypes.path:
+        dependant.path_params.append(field)
+    elif schema.in_ == params.ParamTypes.query:
+        dependant.query_params.append(field)
+    elif schema.in_ == params.ParamTypes.header:
+        dependant.header_params.append(field)
+    else:
+        assert (
+            schema.in_ == params.ParamTypes.cookie
+        ), f"non-body parameters must be in path, query, header or cookie: {param.name}"
+        dependant.cookie_params.append(field)
+
+
+def add_param_to_body_fields(*, param: inspect.Parameter, dependant: Dependant):
+    default_value = Required
+    if not param.default == param.empty:
+        default_value = param.default
+    if isinstance(default_value, Schema):
+        schema = default_value
+        default_value = schema.default
+    else:
+        schema = Schema(default_value)
+    required = default_value == Required
+    annotation = get_annotation_from_schema(param.annotation, schema)
+    field = Field(
+        name=param.name,
+        type_=annotation,
+        default=None if required else default_value,
+        alias=schema.alias or param.name,
+        required=required,
+        model_config=BaseConfig,
+        class_validators=[],
+        schema=schema,
+    )
+    dependant.body_params.append(field)
+
+
+def get_sub_dependant(
+    *, param: inspect.Parameter, path: str
+):
+    depends: params.Depends = param.default
+    if depends.dependency:
+        dependency = depends.dependency
+    else:
+        dependency = param.annotation
+    assert callable(dependency)
+    sub_dependant = get_dependant(path=path, call=dependency, name=param.name)
+    if isinstance(dependency, SecurityBase):
+        sub_dependant.security_schemes.append(dependency)
+    return sub_dependant
+
+
+def get_flat_dependant(dependant: Dependant):
+    flat_dependant = Dependant(
+        path_params=dependant.path_params.copy(),
+        query_params=dependant.query_params.copy(),
+        header_params=dependant.header_params.copy(),
+        cookie_params=dependant.cookie_params.copy(),
+        body_params=dependant.body_params.copy(),
+        security_schemes=dependant.security_schemes.copy(),
+    )
+    for sub_dependant in dependant.dependencies:
+        if sub_dependant is dependant:
+            raise ValueError("recursion", dependant.dependencies)
+        flat_sub = get_flat_dependant(sub_dependant)
+        flat_dependant.path_params.extend(flat_sub.path_params)
+        flat_dependant.query_params.extend(flat_sub.query_params)
+        flat_dependant.header_params.extend(flat_sub.header_params)
+        flat_dependant.cookie_params.extend(flat_sub.cookie_params)
+        flat_dependant.body_params.extend(flat_sub.body_params)
+        flat_dependant.security_schemes.extend(flat_sub.security_schemes)
+    return flat_dependant
+
+
+def get_path_param_names(path: str):
+    return {item.strip("{}") for item in re.findall("{[^}]*}", path)}
+
+
+def get_dependant(*, path: str, call: typing.Callable, name: str = None):
+    path_param_names = get_path_param_names(path)
+    endpoint_signature = inspect.signature(call)
+    signature_params = endpoint_signature.parameters
+    dependant = Dependant(call=call, name=name)
+    for param_name in signature_params:
+        param = signature_params[param_name]
+        if isinstance(param.default, params.Depends):
+            sub_dependant = get_sub_dependant(param=param, path=path)
+            dependant.dependencies.append(sub_dependant)
+    for param_name in signature_params:
+        param = signature_params[param_name]
+        if (
+            (param.default == param.empty) or isinstance(param.default, params.Path)
+        ) and (param_name in path_param_names):
+            assert lenient_issubclass(
+                param.annotation, param_supported_types
+            ), f"Path params must be of type str, int, float or boot: {param}"
+            param = signature_params[param_name]
+            add_param_to_fields(
+                param=param,
+                dependant=dependant,
+                default_schema=params.Path,
+                force_type=params.ParamTypes.path,
+            )
+        elif (param.default == param.empty or param.default is None) and (
+            param.annotation == param.empty
+            or lenient_issubclass(param.annotation, param_supported_types)
+        ):
+            add_param_to_fields(
+                param=param, dependant=dependant, default_schema=params.Query
+            )
+        elif isinstance(param.default, params.Param):
+            if param.annotation != param.empty:
+                assert lenient_issubclass(
+                    param.annotation, param_supported_types
+                ), f"Parameters for Path, Query, Header and Cookies must be of type str, int, float or bool: {param}"
+            add_param_to_fields(
+                param=param, dependant=dependant, default_schema=params.Query
+            )
+        elif lenient_issubclass(param.annotation, Request):
+            dependant.request_param_name = param_name
+        elif not isinstance(param.default, params.Depends):
+            add_param_to_body_fields(param=param, dependant=dependant)
+    return dependant
+
+
+def is_coroutine_callable(call: typing.Callable):
+    if inspect.isfunction(call):
+        return asyncio.iscoroutinefunction(call)
+    elif inspect.isclass(call):
+        return False
+    else:
+        call = getattr(call, "__call__", None)
+        if not call:
+            return False
+        else:
+            return asyncio.iscoroutinefunction(call)
+
+
+async def solve_dependencies(*, request: Request, dependant: Dependant):
+    values = {}
+    errors = []
+    for sub_dependant in dependant.dependencies:
+        sub_values, sub_errors = await solve_dependencies(
+            request=request, dependant=sub_dependant
+        )
+        if sub_errors:
+            return {}, errors
+        if is_coroutine_callable(sub_dependant.call):
+            solved = await sub_dependant.call(**sub_values)
+        else:
+            solved = await run_in_threadpool(sub_dependant.call, **sub_values)
+        values[sub_dependant.name] = solved
+    path_values, path_errors = request_params_to_args(
+        dependant.path_params, request.path_params
+    )
+    query_values, query_errors = request_params_to_args(
+        dependant.query_params, request.query_params
+    )
+    header_values, header_errors = request_params_to_args(
+        dependant.header_params, request.headers
+    )
+    cookie_values, cookie_errors = request_params_to_args(
+        dependant.cookie_params, request.cookies
+    )
+    values.update(path_values)
+    values.update(query_values)
+    values.update(header_values)
+    values.update(cookie_values)
+    errors = path_errors + query_errors + header_errors + cookie_errors
+    if dependant.body_params:
+        body = await request.json()
+        body_values, body_errors = request_body_to_args(dependant.body_params, body)
+        values.update(body_values)
+        errors.extend(body_errors)
+    if dependant.request_param_name:
+        values[dependant.request_param_name] = request
+    return values, errors
+
+
+def get_app(dependant: Dependant):
+    is_coroutine = asyncio.iscoroutinefunction(dependant.call)
+
+    async def app(request: Request) -> Response:
+        values, errors = await solve_dependencies(request=request, dependant=dependant)
+        if errors:
+            errors_out = ValidationError(errors)
+            raise HTTPException(
+                status_code=HTTP_422_UNPROCESSABLE_ENTITY, detail=errors_out.errors()
+            )
+        else:
+            if is_coroutine:
+                raw_response = await dependant.call(**values)
+            else:
+                raw_response = await run_in_threadpool(dependant.call, **values)
+            if isinstance(raw_response, Response):
+                return raw_response
+            else:
+                return JSONResponse(content=jsonable_encoder(raw_response))
+    return app
+
+
+def get_openapi_params(dependant: Dependant):
+    flat_dependant = get_flat_dependant(dependant)
+    return (
+        flat_dependant.path_params
+        + flat_dependant.query_params
+        + flat_dependant.header_params
+        + flat_dependant.cookie_params
+    )
+
+
+class APIRoute(routing.Route):
+    def __init__(
+        self,
+        path: str,
+        endpoint: typing.Callable,
+        *,
+        methods: typing.List[str] = None,
+        name: str = None,
+        include_in_schema: bool = True,
+        tags: typing.List[str] = [],
+        summary: str = None,
+        description: str = None,
+        operation_id: str = None,
+        deprecated: bool = None,
+        response_type: typing.Type = None,
+        response_description: str = "Successful Response",
+        response_code=200,
+        response_wrapper=JSONResponse,
+    ) -> None:
+        # TODO define how to read and provide security params, and how to have them globally too
+        # TODO implement dependencies and injection
+        # TODO refactor code structure
+        # TODO create testing
+        # TODO testing coverage
+        assert path.startswith("/"), "Routed paths must always start '/'"
+        self.path = path
+        self.endpoint = endpoint
+        self.name = get_name(endpoint) if name is None else name
+        self.include_in_schema = include_in_schema
+        self.tags = tags
+        self.summary = summary
+        self.description = description
+        self.operation_id = operation_id
+        self.deprecated = deprecated
+        self.request_body: typing.Union[BaseModel, Field, None] = None
+        self.response_description = response_description
+        self.response_code = response_code
+        self.response_wrapper = response_wrapper
+        self.response_field = None
+        if response_type:
+            assert lenient_issubclass(
+                response_wrapper, JSONResponse
+            ), "To declare a type the response must be a JSON response"
+            self.response_type = response_type
+            response_name = "Response_" + self.name
+            self.response_field = Field(
+                name=response_name,
+                type_=self.response_type,
+                class_validators=[],
+                default=None,
+                required=False,
+                model_config=BaseConfig(),
+                schema=Schema(None),
+            )
+        else:
+            self.response_type = None
+        if methods is None:
+            methods = ["GET"]
+        self.methods = methods
+        self.path_regex, self.path_format, self.param_convertors = self.compile_path(
+            path
+        )
+        assert inspect.isfunction(endpoint) or inspect.ismethod(
+            endpoint
+        ), f"An endpoint must be a function or method"
+
+        self.dependant = get_dependant(path=path, call=self.endpoint)
+        # flat_dependant = get_flat_dependant(self.dependant)
+        # path_param_names = get_path_param_names(path)
+        # for path_param in path_param_names:
+        #     assert path_param in {
+        #         f.alias for f in flat_dependant.path_params
+        #     }, f"Path parameter must be defined as a function parameter or be defined by a dependency: {path_param}"
+
+        if self.dependant.body_params:
+            first_param = self.dependant.body_params[0]
+            sub_key = getattr(first_param.schema, "sub_key", None)
+            if len(self.dependant.body_params) == 1 and not sub_key:
+                self.request_body = first_param
+            else:
+                model_name = "Body_" + self.name
+                BodyModel = create_model(model_name)
+                for f in self.dependant.body_params:
+                    BodyModel.__fields__[f.name] = f
+                required = any(True for f in self.dependant.body_params if f.required)
+                field = Field(
+                    name="body",
+                    type_=BodyModel,
+                    default=None,
+                    required=required,
+                    model_config=BaseConfig,
+                    class_validators=[],
+                    alias="body",
+                    schema=Schema(None),
+                )
+                self.request_body = field
+
+        self.app = request_response(get_app(dependant=self.dependant))
+
+
+class APIRouter(routing.Router):
+    def add_api_route(
+        self,
+        path: str,
+        endpoint: typing.Callable,
+        methods: typing.List[str] = None,
+        name: str = None,
+        include_in_schema: bool = True,
+        tags: typing.List[str] = [],
+        summary: str = None,
+        description: str = None,
+        operation_id: str = None,
+        deprecated: bool = None,
+        response_type: typing.Type = None,
+        response_description: str = "Successful Response",
+        response_code=200,
+        response_wrapper=JSONResponse,
+    ) -> None:
+        route = APIRoute(
+            path,
+            endpoint=endpoint,
+            methods=methods,
+            name=name,
+            include_in_schema=include_in_schema,
+            tags=tags,
+            summary=summary,
+            description=description,
+            operation_id=operation_id,
+            deprecated=deprecated,
+            response_type=response_type,
+            response_description=response_description,
+            response_code=response_code,
+            response_wrapper=response_wrapper,
+        )
+        self.routes.append(route)
+
+    def api_route(
+        self,
+        path: str,
+        methods: typing.List[str] = None,
+        name: str = None,
+        include_in_schema: bool = True,
+        tags: typing.List[str] = [],
+        summary: str = None,
+        description: str = None,
+        operation_id: str = None,
+        deprecated: bool = None,
+        response_type: typing.Type = None,
+        response_description: str = "Successful Response",
+        response_code=200,
+        response_wrapper=JSONResponse,
+    ) -> typing.Callable:
+        def decorator(func: typing.Callable) -> typing.Callable:
+            self.add_api_route(
+                path,
+                func,
+                methods=methods,
+                name=name,
+                include_in_schema=include_in_schema,
+                tags=tags,
+                summary=summary,
+                description=description,
+                operation_id=operation_id,
+                deprecated=deprecated,
+                response_type=response_type,
+                response_description=response_description,
+                response_code=response_code,
+                response_wrapper=response_wrapper,
+            )
+            return func
+
+        return decorator
+
+    def get(
+        self,
+        path: str,
+        name: str = None,
+        include_in_schema: bool = True,
+        tags: typing.List[str] = [],
+        summary: str = None,
+        description: str = None,
+        operation_id: str = None,
+        deprecated: bool = None,
+        response_type: typing.Type = None,
+        response_description: str = "Successful Response",
+        response_code=200,
+        response_wrapper=JSONResponse,
+    ):
+        return self.api_route(
+            path=path,
+            methods=["GET"],
+            name=name,
+            include_in_schema=include_in_schema,
+            tags=tags,
+            summary=summary,
+            description=description,
+            operation_id=operation_id,
+            deprecated=deprecated,
+            response_type=response_type,
+            response_description=response_description,
+            response_code=response_code,
+            response_wrapper=response_wrapper,
+        )
+
+    def put(
+        self,
+        path: str,
+        name: str = None,
+        include_in_schema: bool = True,
+        tags: typing.List[str] = [],
+        summary: str = None,
+        description: str = None,
+        operation_id: str = None,
+        deprecated: bool = None,
+        response_type: typing.Type = None,
+        response_description: str = "Successful Response",
+        response_code=200,
+        response_wrapper=JSONResponse,
+    ):
+        return self.api_route(
+            path=path,
+            methods=["PUT"],
+            name=name,
+            include_in_schema=include_in_schema,
+            tags=tags,
+            summary=summary,
+            description=description,
+            operation_id=operation_id,
+            deprecated=deprecated,
+            response_type=response_type,
+            response_description=response_description,
+            response_code=response_code,
+            response_wrapper=response_wrapper,
+        )
+
+    def post(
+        self,
+        path: str,
+        name: str = None,
+        include_in_schema: bool = True,
+        tags: typing.List[str] = [],
+        summary: str = None,
+        description: str = None,
+        operation_id: str = None,
+        deprecated: bool = None,
+        response_type: typing.Type = None,
+        response_description: str = "Successful Response",
+        response_code=200,
+        response_wrapper=JSONResponse,
+    ):
+        return self.api_route(
+            path=path,
+            methods=["POST"],
+            name=name,
+            include_in_schema=include_in_schema,
+            tags=tags,
+            summary=summary,
+            description=description,
+            operation_id=operation_id,
+            deprecated=deprecated,
+            response_type=response_type,
+            response_description=response_description,
+            response_code=response_code,
+            response_wrapper=response_wrapper,
+        )
+
+    def delete(
+        self,
+        path: str,
+        name: str = None,
+        include_in_schema: bool = True,
+        tags: typing.List[str] = [],
+        summary: str = None,
+        description: str = None,
+        operation_id: str = None,
+        deprecated: bool = None,
+        response_type: typing.Type = None,
+        response_description: str = "Successful Response",
+        response_code=200,
+        response_wrapper=JSONResponse,
+    ):
+        return self.api_route(
+            path=path,
+            methods=["DELETE"],
+            name=name,
+            include_in_schema=include_in_schema,
+            tags=tags,
+            summary=summary,
+            description=description,
+            operation_id=operation_id,
+            deprecated=deprecated,
+            response_type=response_type,
+            response_description=response_description,
+            response_code=response_code,
+            response_wrapper=response_wrapper,
+        )
+
+    def options(
+        self,
+        path: str,
+        name: str = None,
+        include_in_schema: bool = True,
+        tags: typing.List[str] = [],
+        summary: str = None,
+        description: str = None,
+        operation_id: str = None,
+        deprecated: bool = None,
+        response_type: typing.Type = None,
+        response_description: str = "Successful Response",
+        response_code=200,
+        response_wrapper=JSONResponse,
+    ):
+        return self.api_route(
+            path=path,
+            methods=["OPTIONS"],
+            name=name,
+            include_in_schema=include_in_schema,
+            tags=tags,
+            summary=summary,
+            description=description,
+            operation_id=operation_id,
+            deprecated=deprecated,
+            response_type=response_type,
+            response_description=response_description,
+            response_code=response_code,
+            response_wrapper=response_wrapper,
+        )
+
+    def head(
+        self,
+        path: str,
+        name: str = None,
+        include_in_schema: bool = True,
+        tags: typing.List[str] = [],
+        summary: str = None,
+        description: str = None,
+        operation_id: str = None,
+        deprecated: bool = None,
+        response_type: typing.Type = None,
+        response_description: str = "Successful Response",
+        response_code=200,
+        response_wrapper=JSONResponse,
+    ):
+        return self.api_route(
+            path=path,
+            methods=["HEAD"],
+            name=name,
+            include_in_schema=include_in_schema,
+            tags=tags,
+            summary=summary,
+            description=description,
+            operation_id=operation_id,
+            deprecated=deprecated,
+            response_type=response_type,
+            response_description=response_description,
+            response_code=response_code,
+            response_wrapper=response_wrapper,
+        )
+
+    def patch(
+        self,
+        path: str,
+        name: str = None,
+        include_in_schema: bool = True,
+        tags: typing.List[str] = [],
+        summary: str = None,
+        description: str = None,
+        operation_id: str = None,
+        deprecated: bool = None,
+        response_type: typing.Type = None,
+        response_description: str = "Successful Response",
+        response_code=200,
+        response_wrapper=JSONResponse,
+    ):
+        return self.api_route(
+            path=path,
+            methods=["PATCH"],
+            name=name,
+            include_in_schema=include_in_schema,
+            tags=tags,
+            summary=summary,
+            description=description,
+            operation_id=operation_id,
+            deprecated=deprecated,
+            response_type=response_type,
+            response_description=response_description,
+            response_code=response_code,
+            response_wrapper=response_wrapper,
+        )
+
+    def trace(
+        self,
+        path: str,
+        name: str = None,
+        include_in_schema: bool = True,
+        tags: typing.List[str] = [],
+        summary: str = None,
+        description: str = None,
+        operation_id: str = None,
+        deprecated: bool = None,
+        response_type: typing.Type = None,
+        response_description: str = "Successful Response",
+        response_code=200,
+        response_wrapper=JSONResponse,
+    ):
+        return self.api_route(
+            path=path,
+            methods=["TRACE"],
+            name=name,
+            include_in_schema=include_in_schema,
+            tags=tags,
+            summary=summary,
+            description=description,
+            operation_id=operation_id,
+            deprecated=deprecated,
+            response_type=response_type,
+            response_description=response_description,
+            response_code=response_code,
+            response_wrapper=response_wrapper,
+        )
diff --git a/fastapi/security/__init__.py b/fastapi/security/__init__.py
new file mode 100644 (file)
index 0000000..e69de29
diff --git a/fastapi/security/api_key.py b/fastapi/security/api_key.py
new file mode 100644 (file)
index 0000000..4b6766e
--- /dev/null
@@ -0,0 +1,40 @@
+from starlette.requests import Request
+
+from pydantic import Schema
+from enum import Enum
+from .base import SecurityBase, Types
+
+__all__ = ["APIKeyIn", "APIKeyBase", "APIKeyQuery", "APIKeyHeader", "APIKeyCookie"]
+
+
+class APIKeyIn(Enum):
+    query = "query"
+    header = "header"
+    cookie = "cookie"
+
+
+class APIKeyBase(SecurityBase):
+    type_ = Schema(Types.apiKey, alias="type")
+    in_: str = Schema(..., alias="in")
+    name: str
+
+
+class APIKeyQuery(APIKeyBase):
+    in_ = Schema(APIKeyIn.query, alias="in")
+    
+    async def __call__(self, requests: Request):
+        return requests.query_params.get(self.name)
+
+
+class APIKeyHeader(APIKeyBase):
+    in_ = Schema(APIKeyIn.header, alias="in")
+
+    async def __call__(self, requests: Request):
+        return requests.headers.get(self.name)
+
+
+class APIKeyCookie(APIKeyBase):
+    in_ = Schema(APIKeyIn.cookie, alias="in")
+
+    async def __call__(self, requests: Request):
+        return requests.cookies.get(self.name)
diff --git a/fastapi/security/base.py b/fastapi/security/base.py
new file mode 100644 (file)
index 0000000..37433ff
--- /dev/null
@@ -0,0 +1,17 @@
+from enum import Enum
+from pydantic import BaseModel, Schema
+
+__all__ = ["Types", "SecurityBase"]
+
+
+class Types(Enum):
+    apiKey = "apiKey"
+    http = "http"
+    oauth2 = "oauth2"
+    openIdConnect = "openIdConnect"
+
+
+class SecurityBase(BaseModel):
+    scheme_name: str = None
+    type_: Types = Schema(..., alias="type")
+    description: str = None
diff --git a/fastapi/security/http.py b/fastapi/security/http.py
new file mode 100644 (file)
index 0000000..aaaf866
--- /dev/null
@@ -0,0 +1,27 @@
+
+from starlette.requests import Request
+from pydantic import Schema
+from .base import SecurityBase, Types
+
+__all__ = ["HTTPBase", "HTTPBasic", "HTTPBearer", "HTTPDigest"]
+
+
+class HTTPBase(SecurityBase):
+    type_ = Schema(Types.http, alias="type")
+    scheme: str
+
+    async def __call__(self, request: Request):
+        return request.headers.get("Authorization")
+
+
+class HTTPBasic(HTTPBase):
+    scheme = "basic"
+
+
+class HTTPBearer(HTTPBase):
+    scheme = "bearer"
+    bearerFormat: str = None
+
+
+class HTTPDigest(HTTPBase):
+    scheme = "digest"
diff --git a/fastapi/security/oauth2.py b/fastapi/security/oauth2.py
new file mode 100644 (file)
index 0000000..a6607ef
--- /dev/null
@@ -0,0 +1,45 @@
+from typing import Dict
+
+from pydantic import BaseModel, Schema
+
+from starlette.requests import Request
+from .base import SecurityBase, Types
+
+# __all__ = ["HTTPBase", "HTTPBasic", "HTTPBearer", "HTTPDigest"]
+
+
+class OAuthFlow(BaseModel):
+    refreshUrl: str = None
+    scopes: Dict[str, str] = {}
+
+
+class OAuthFlowImplicit(OAuthFlow):
+    authorizationUrl: str
+
+
+class OAuthFlowPassword(OAuthFlow):
+    tokenUrl: str
+
+
+class OAuthFlowClientCredentials(OAuthFlow):
+    tokenUrl: str
+
+
+class OAuthFlowAuthorizationCode(OAuthFlow):
+    authorizationUrl: str
+    tokenUrl: str
+
+
+class OAuthFlows(BaseModel):
+    implicit: OAuthFlowImplicit = None
+    password: OAuthFlowPassword = None
+    clientCredentials: OAuthFlowClientCredentials = None
+    authorizationCode: OAuthFlowAuthorizationCode = None
+
+
+class OAuth2(SecurityBase):
+    type_ = Schema(Types.oauth2, alias="type")
+    flows: OAuthFlows
+
+    async def __call__(self, request: Request):
+        return request.headers.get("Authorization")
diff --git a/fastapi/security/open_id_connect_url.py b/fastapi/security/open_id_connect_url.py
new file mode 100644 (file)
index 0000000..2e7791a
--- /dev/null
@@ -0,0 +1,10 @@
+from starlette.requests import Request
+
+from .base import SecurityBase, Types
+
+class OpenIdConnect(SecurityBase):
+    type_ = Types.openIdConnect
+    openIdConnectUrl: str
+
+    async def __call__(self, request: Request):
+        return request.headers.get("Authorization")