]> git.ipfire.org Git - thirdparty/fastapi/fastapi.git/commitdiff
:sparkles: Refactor param extraction using Pydantic Field (#278)
authorSebastián Ramírez <tiangolo@gmail.com>
Thu, 30 May 2019 13:40:43 +0000 (17:40 +0400)
committerGitHub <noreply@github.com>
Thu, 30 May 2019 13:40:43 +0000 (17:40 +0400)
* :sparkles: Refactor parameter dependency using Pydantic Field

* :arrow_up: Upgrade required Pydantic version with latest Shape values

* :sparkles: Add tutorials and code for using Enum and Optional

* :white_check_mark: Add tests for tutorials with new types and extra cases

* :recycle: Format, clean, and add annotations to dependencies.utils

* :memo: Update tutorial for query parameters with list defaults

* :white_check_mark: Add tests for query param with list default

14 files changed:
docs/img/tutorial/path-params/image03.png [new file with mode: 0644]
docs/src/path_params/tutorial005.py [new file with mode: 0644]
docs/src/query_params/tutorial007.py [new file with mode: 0644]
docs/src/query_params_str_validations/tutorial012.py [new file with mode: 0644]
docs/tutorial/path-params.md
docs/tutorial/query-params-str-validations.md
docs/tutorial/query-params.md
fastapi/dependencies/utils.py
fastapi/openapi/utils.py
pyproject.toml
tests/test_invalid_sequence_param.py [new file with mode: 0644]
tests/test_tutorial/test_path_params/test_tutorial005.py [new file with mode: 0644]
tests/test_tutorial/test_query_params/test_tutorial007.py [new file with mode: 0644]
tests/test_tutorial/test_query_params_str_validations/test_tutorial012.py [new file with mode: 0644]

diff --git a/docs/img/tutorial/path-params/image03.png b/docs/img/tutorial/path-params/image03.png
new file mode 100644 (file)
index 0000000..d08645d
Binary files /dev/null and b/docs/img/tutorial/path-params/image03.png differ
diff --git a/docs/src/path_params/tutorial005.py b/docs/src/path_params/tutorial005.py
new file mode 100644 (file)
index 0000000..d4f24bc
--- /dev/null
@@ -0,0 +1,21 @@
+from enum import Enum
+
+from fastapi import FastAPI
+
+
+class ModelName(Enum):
+    alexnet = "alexnet"
+    resnet = "resnet"
+    lenet = "lenet"
+
+
+app = FastAPI()
+
+
+@app.get("/model/{model_name}")
+async def get_model(model_name: ModelName):
+    if model_name == ModelName.alexnet:
+        return {"model_name": model_name, "message": "Deep Learning FTW!"}
+    if model_name.value == "lenet":
+        return {"model_name": model_name, "message": "LeCNN all the images"}
+    return {"model_name": model_name, "message": "Have some residuals"}
diff --git a/docs/src/query_params/tutorial007.py b/docs/src/query_params/tutorial007.py
new file mode 100644 (file)
index 0000000..8ef5b30
--- /dev/null
@@ -0,0 +1,11 @@
+from typing import Optional
+
+from fastapi import FastAPI
+
+app = FastAPI()
+
+
+@app.get("/items/{item_id}")
+async def read_user_item(item_id: str, limit: Optional[int] = None):
+    item = {"item_id": item_id, "limit": limit}
+    return item
diff --git a/docs/src/query_params_str_validations/tutorial012.py b/docs/src/query_params_str_validations/tutorial012.py
new file mode 100644 (file)
index 0000000..7ea9f01
--- /dev/null
@@ -0,0 +1,11 @@
+from typing import List
+
+from fastapi import FastAPI, Query
+
+app = FastAPI()
+
+
+@app.get("/items/")
+async def read_items(q: List[str] = Query(["foo", "bar"])):
+    query_items = {"q": q}
+    return query_items
index b7cf9d4df5f04869d6a8b39a1bc9900b8e4b050c..96e29366e1455c4e74f4f525ce53771687ec8747 100644 (file)
@@ -35,7 +35,7 @@ If you run this example and open your browser at <a href="http://127.0.0.1:8000/
 
 !!! check
     Notice that the value your function received (and returned) is `3`, as a Python `int`, not a string `"3"`.
-    
+
     So, with that type declaration, **FastAPI** gives you automatic request <abbr title="converting the string that comes from an HTTP request into Python data">"parsing"</abbr>.
 
 ## Data validation
@@ -61,12 +61,11 @@ because the path parameter `item_id` had a value of `"foo"`, which is not an `in
 
 The same error would appear if you provided a `float` instead of an int, as in: <a href="http://127.0.0.1:8000/items/4.2" target="_blank">http://127.0.0.1:8000/items/4.2</a>
 
-
 !!! check
     So, with the same Python type declaration, **FastAPI** gives you data validation.
 
-    Notice that the error also clearly states exactly the point where the validation didn't pass. 
-    
+    Notice that the error also clearly states exactly the point where the validation didn't pass.
+
     This is incredibly helpful while developing and debugging code that interacts with your API.
 
 ## Documentation
@@ -96,8 +95,7 @@ All the data validation is performed under the hood by <a href="https://pydantic
 
 You can use the same type declarations with `str`, `float`, `bool` and many other complex data types.
 
-These are explored in the next chapters of the tutorial.
-
+Several of these are explored in the next chapters of the tutorial.
 
 ## Order matters
 
@@ -115,6 +113,73 @@ Because path operations are evaluated in order, you need to make sure that the p
 
 Otherwise, the path for `/users/{user_id}` would match also for `/users/me`, "thinking" that it's receiving a parameter `user_id` with a value of `"me"`.
 
+## Predefined values
+
+If you have a *path operation* that receives a *path parameter*, but you want the possible valid *path parameter* values to be predefined, you can use a standard Python <abbr title="Enumeration">`Enum`</abbr>.
+
+### Create an `Enum` class
+
+Import `Enum` and create a sub-class that inherits from it.
+
+And create class attributes with fixed values, those fixed values will be the available valid values:
+
+```Python hl_lines="1 6 7 8 9"
+{!./src/path_params/tutorial005.py!}
+```
+
+!!! info
+    <a href="https://docs.python.org/3/library/enum.html" target="_blank">Enumerations (or enums) are available in Python</a> since version 3.4.
+
+!!! tip
+    If you are wondering, "AlexNet", "ResNet", and "LeNet" are just names of Machine Learning <abbr title="Technically, Deep Learning model architectures">models</abbr>.
+
+### Declare a *path parameter*
+
+Then create a *path parameter* with a type annotation using the enum class you created (`ModelName`):
+
+```Python hl_lines="16"
+{!./src/path_params/tutorial005.py!}
+```
+
+### Check the docs
+
+Because the available values for the *path parameter* are specified, the interactive docs can show them nicely:
+
+<img src="/img/tutorial/path-params/image03.png">
+
+### Working with Python *enumerations*
+
+The value of the *path parameter* will be an *enumeration member*.
+
+#### Compare *enumeration members*
+
+You can compare it with the *enumeration member* in your created enum `ModelName`:
+
+```Python hl_lines="17"
+{!./src/path_params/tutorial005.py!}
+```
+
+#### Get the *enumeration value*
+
+You can get the actual value (a `str` in this case) using `model_name.value`, or in general, `your_enum_member.value`:
+
+```Python hl_lines="19"
+{!./src/path_params/tutorial005.py!}
+```
+
+!!! tip
+    You could also access the value `"lenet"` with `ModelName.lenet.value`.
+
+#### Return *enumeration members*
+
+You can return *enum members* from your *path operation*, even nested in a JSON body (e.g. a `dict`).
+
+They will be converted to their corresponding values before returning them to the client:
+
+```Python hl_lines="18 20 21"
+{!./src/path_params/tutorial005.py!}
+```
+
 ## Path parameters containing paths
 
 Let's say you have a *path operation* with a path `/files/{file_path}`.
index a82018437b446e74e493f6d6de15a01628829b4f..4258a71fddef34c666644835990106a58e162f0e 100644 (file)
@@ -12,7 +12,6 @@ The query parameter `q` is of type `str`, and by default is `None`, so it is opt
 
 We are going to enforce that even though `q` is optional, whenever it is provided, it **doesn't exceed a length of 50 characters**.
 
-
 ### Import `Query`
 
 To achieve that, first import `Query` from `fastapi`:
@@ -29,7 +28,7 @@ And now use it as the default value of your parameter, setting the parameter `ma
 {!./src/query_params_str_validations/tutorial002.py!}
 ```
 
-As we have to replace the default value `None` with `Query(None)`, the first parameter to `Query` serves the same purpose of defining that default value. 
+As we have to replace the default value `None` with `Query(None)`, the first parameter to `Query` serves the same purpose of defining that default value.
 
 So:
 
@@ -41,7 +40,7 @@ q: str = Query(None)
 
 ```Python
 q: str = None
-``` 
+```
 
 But it declares it explicitly as being a query parameter.
 
@@ -53,7 +52,6 @@ q: str = Query(None, max_length=50)
 
 This will validate the data, show a clear error when the data is not valid, and document the parameter in the OpenAPI schema path operation.
 
-
 ## Add more validations
 
 You can also add a parameter `min_length`:
@@ -119,7 +117,7 @@ So, when you need to declare a value as required while using `Query`, you can us
 {!./src/query_params_str_validations/tutorial006.py!}
 ```
 
-!!! info 
+!!! info
     If you hadn't seen that `...` before: it is a a special single value, it is <a href="https://docs.python.org/3/library/constants.html#Ellipsis" target="_blank">part of Python and is called "Ellipsis"</a>.
 
 This will let **FastAPI** know that this parameter is required.
@@ -156,11 +154,35 @@ So, the response to that URL would be:
 !!! tip
     To declare a query parameter with a type of `list`, like in the example above, you need to explicitly use `Query`, otherwise it would be interpreted as a request body.
 
-
 The interactive API docs will update accordingly, to allow multiple values:
 
 <img src="/img/tutorial/query-params-str-validations/image02.png">
 
+### Query parameter list / multiple values with defaults
+
+And you can also define a default `list` of values if none are provided:
+
+```Python hl_lines="9"
+{!./src/query_params_str_validations/tutorial012.py!}
+```
+
+If you go to:
+
+```
+http://localhost:8000/items/
+```
+
+the default of `q` will be: `["foo", "bar"]` and your response will be:
+
+```JSON
+{
+  "q": [
+    "foo",
+    "bar"
+  ]
+}
+```
+
 ## Declare more metadata
 
 You can add more information about the parameter.
index 54a71f36d67efe5ba2c6d736dee2a7801b415497..85a69205d4d06a7df5a598c3bd5447fbe3a00d48 100644 (file)
@@ -186,3 +186,39 @@ In this case, there are 3 query parameters:
 * `needy`, a required `str`.
 * `skip`, an `int` with a default value of `0`.
 * `limit`, an optional `int`.
+
+!!! tip
+    You could also use `Enum`s <a href="https://fastapi.tiangolo.com/tutorial/path-params/#predefined-values" target="_blank">the same way as with *path parameters*</a>.
+
+## Optional type declarations
+
+!!! warning
+    This might be an advanced use case.
+
+    You might want to skip it.
+
+If you are using `mypy` it could complain with type declarations like:
+
+```Python
+limit: int = None
+```
+
+With an error like:
+
+```
+Incompatible types in assignment (expression has type "None", variable has type "int")
+```
+
+In those cases you can use `Optional` to tell `mypy` that the value could be `None`, like:
+
+```Python
+from typing import Optional
+
+limit: Optional[int] = None
+```
+
+In a *path operation* that could look like:
+
+```Python hl_lines="9"
+{!./src/query_params/tutorial007.py!}
+```
index 194187f28ca06898cd828abdfb93e47bba22cdcf..2596d5754142a194d3d67814cc3bebf0cd5c371c 100644 (file)
@@ -1,8 +1,6 @@
 import asyncio
 import inspect
 from copy import deepcopy
-from datetime import date, datetime, time, timedelta
-from decimal import Decimal
 from typing import (
     Any,
     Callable,
@@ -14,8 +12,8 @@ from typing import (
     Tuple,
     Type,
     Union,
+    cast,
 )
-from uuid import UUID
 
 from fastapi import params
 from fastapi.dependencies.models import Dependant, SecurityRequirement
@@ -23,7 +21,7 @@ from fastapi.security.base import SecurityBase
 from fastapi.security.oauth2 import OAuth2, SecurityScopes
 from fastapi.security.open_id_connect_url import OpenIdConnect
 from fastapi.utils import get_path_param_names
-from pydantic import BaseConfig, Schema, create_model
+from pydantic import BaseConfig, BaseModel, Schema, create_model
 from pydantic.error_wrappers import ErrorWrapper
 from pydantic.errors import MissingError
 from pydantic.fields import Field, Required, Shape
@@ -35,22 +33,21 @@ from starlette.datastructures import FormData, Headers, QueryParams, UploadFile
 from starlette.requests import Request
 from starlette.websockets import WebSocket
 
-param_supported_types = (
-    str,
-    int,
-    float,
-    bool,
-    UUID,
-    date,
-    datetime,
-    time,
-    timedelta,
-    Decimal,
-)
-
-sequence_shapes = {Shape.LIST, Shape.SET, Shape.TUPLE}
+sequence_shapes = {
+    Shape.LIST,
+    Shape.SET,
+    Shape.TUPLE,
+    Shape.SEQUENCE,
+    Shape.TUPLE_ELLIPS,
+}
 sequence_types = (list, set, tuple)
-sequence_shape_to_type = {Shape.LIST: list, Shape.SET: set, Shape.TUPLE: tuple}
+sequence_shape_to_type = {
+    Shape.LIST: list,
+    Shape.SET: set,
+    Shape.TUPLE: tuple,
+    Shape.SEQUENCE: list,
+    Shape.TUPLE_ELLIPS: list,
+}
 
 
 def get_param_sub_dependant(
@@ -126,6 +123,26 @@ def get_flat_dependant(dependant: Dependant) -> Dependant:
     return flat_dependant
 
 
+def is_scalar_field(field: Field) -> bool:
+    return (
+        field.shape == Shape.SINGLETON
+        and not lenient_issubclass(field.type_, BaseModel)
+        and not isinstance(field.schema, params.Body)
+    )
+
+
+def is_scalar_sequence_field(field: Field) -> bool:
+    if field.shape in sequence_shapes and not lenient_issubclass(
+        field.type_, BaseModel
+    ):
+        if field.sub_fields is not None:
+            for sub_field in field.sub_fields:
+                if not is_scalar_field(sub_field):
+                    return False
+        return True
+    return False
+
+
 def get_dependant(
     *, path: str, call: Callable, name: str = None, security_scopes: List[str] = None
 ) -> Dependant:
@@ -133,83 +150,78 @@ def get_dependant(
     endpoint_signature = inspect.signature(call)
     signature_params = endpoint_signature.parameters
     dependant = Dependant(call=call, name=name)
-    for param_name in signature_params:
-        param = signature_params[param_name]
+    for param_name, param in signature_params.items():
         if isinstance(param.default, params.Depends):
             sub_dependant = get_param_sub_dependant(
                 param=param, path=path, security_scopes=security_scopes
             )
             dependant.dependencies.append(sub_dependant)
-    for param_name in signature_params:
-        param = signature_params[param_name]
-        if (
-            (param.default == param.empty) or isinstance(param.default, params.Path)
-        ) and (param_name in path_param_names):
-            assert (
-                lenient_issubclass(param.annotation, param_supported_types)
-                or param.annotation == param.empty
+    for param_name, param in signature_params.items():
+        if isinstance(param.default, params.Depends):
+            continue
+        if add_non_field_param_to_dependency(param=param, dependant=dependant):
+            continue
+        param_field = get_param_field(param=param, default_schema=params.Query)
+        if param_name in path_param_names:
+            assert param.default == param.empty or isinstance(
+                param.default, params.Path
+            ), "Path params must have no defaults or use Path(...)"
+            assert is_scalar_field(
+                field=param_field
             ), f"Path params must be of one of the supported types"
-            add_param_to_fields(
+            param_field = get_param_field(
                 param=param,
-                dependant=dependant,
                 default_schema=params.Path,
                 force_type=params.ParamTypes.path,
             )
-        elif (
-            param.default == param.empty
-            or param.default is None
-            or isinstance(param.default, param_supported_types)
-        ) and (
-            param.annotation == param.empty
-            or lenient_issubclass(param.annotation, param_supported_types)
-        ):
-            add_param_to_fields(
-                param=param, dependant=dependant, default_schema=params.Query
-            )
-        elif isinstance(param.default, params.Param):
-            if param.annotation != param.empty:
-                origin = getattr(param.annotation, "__origin__", None)
-                param_all_types = param_supported_types + (list, tuple, set)
-                if isinstance(param.default, (params.Query, params.Header)):
-                    assert lenient_issubclass(
-                        param.annotation, param_all_types
-                    ) or lenient_issubclass(
-                        origin, param_all_types
-                    ), f"Parameters for Query and Header must be of type str, int, float, bool, list, tuple or set: {param}"
-                else:
-                    assert lenient_issubclass(
-                        param.annotation, param_supported_types
-                    ), f"Parameters for Path and Cookies must be of type str, int, float, bool: {param}"
-            add_param_to_fields(
-                param=param, dependant=dependant, default_schema=params.Query
-            )
-        elif lenient_issubclass(param.annotation, Request):
-            dependant.request_param_name = param_name
-        elif lenient_issubclass(param.annotation, WebSocket):
-            dependant.websocket_param_name = param_name
-        elif lenient_issubclass(param.annotation, BackgroundTasks):
-            dependant.background_tasks_param_name = param_name
-        elif lenient_issubclass(param.annotation, SecurityScopes):
-            dependant.security_scopes_param_name = param_name
-        elif not isinstance(param.default, params.Depends):
-            add_param_to_body_fields(param=param, dependant=dependant)
+            add_param_to_fields(field=param_field, dependant=dependant)
+        elif is_scalar_field(field=param_field):
+            add_param_to_fields(field=param_field, dependant=dependant)
+        elif isinstance(
+            param.default, (params.Query, params.Header)
+        ) and is_scalar_sequence_field(param_field):
+            add_param_to_fields(field=param_field, dependant=dependant)
+        else:
+            assert isinstance(
+                param_field.schema, params.Body
+            ), f"Param: {param_field.name} can only be a request body, using Body(...)"
+            dependant.body_params.append(param_field)
     return dependant
 
 
-def add_param_to_fields(
+def add_non_field_param_to_dependency(
+    *, param: inspect.Parameter, dependant: Dependant
+) -> Optional[bool]:
+    if lenient_issubclass(param.annotation, Request):
+        dependant.request_param_name = param.name
+        return True
+    elif lenient_issubclass(param.annotation, WebSocket):
+        dependant.websocket_param_name = param.name
+        return True
+    elif lenient_issubclass(param.annotation, BackgroundTasks):
+        dependant.background_tasks_param_name = param.name
+        return True
+    elif lenient_issubclass(param.annotation, SecurityScopes):
+        dependant.security_scopes_param_name = param.name
+        return True
+    return None
+
+
+def get_param_field(
     *,
     param: inspect.Parameter,
-    dependant: Dependant,
-    default_schema: Type[Schema] = params.Param,
+    default_schema: Type[params.Param] = params.Param,
     force_type: params.ParamTypes = None,
-) -> None:
+) -> Field:
     default_value = Required
+    had_schema = False
     if not param.default == param.empty:
         default_value = param.default
-    if isinstance(default_value, params.Param):
+    if isinstance(default_value, Schema):
+        had_schema = True
         schema = default_value
         default_value = schema.default
-        if getattr(schema, "in_", None) is None:
+        if isinstance(schema, params.Param) and getattr(schema, "in_", None) is None:
             schema.in_ = default_schema.in_
         if force_type:
             schema.in_ = force_type
@@ -234,43 +246,26 @@ def add_param_to_fields(
         class_validators={},
         schema=schema,
     )
-    if schema.in_ == params.ParamTypes.path:
+    if not had_schema and not is_scalar_field(field=field):
+        field.schema = params.Body(schema.default)
+    return field
+
+
+def add_param_to_fields(*, field: Field, dependant: Dependant) -> None:
+    field.schema = cast(params.Param, field.schema)
+    if field.schema.in_ == params.ParamTypes.path:
         dependant.path_params.append(field)
-    elif schema.in_ == params.ParamTypes.query:
+    elif field.schema.in_ == params.ParamTypes.query:
         dependant.query_params.append(field)
-    elif schema.in_ == params.ParamTypes.header:
+    elif field.schema.in_ == params.ParamTypes.header:
         dependant.header_params.append(field)
     else:
         assert (
-            schema.in_ == params.ParamTypes.cookie
-        ), f"non-body parameters must be in path, query, header or cookie: {param.name}"
+            field.schema.in_ == params.ParamTypes.cookie
+        ), f"non-body parameters must be in path, query, header or cookie: {field.name}"
         dependant.cookie_params.append(field)
 
 
-def add_param_to_body_fields(*, param: inspect.Parameter, dependant: Dependant) -> None:
-    default_value = Required
-    if not param.default == param.empty:
-        default_value = param.default
-    if isinstance(default_value, Schema):
-        schema = default_value
-        default_value = schema.default
-    else:
-        schema = Schema(default_value)
-    required = default_value == Required
-    annotation = get_annotation_from_schema(param.annotation, schema)
-    field = Field(
-        name=param.name,
-        type_=annotation,
-        default=None if required else default_value,
-        alias=schema.alias or param.name,
-        required=required,
-        model_config=BaseConfig,
-        class_validators={},
-        schema=schema,
-    )
-    dependant.body_params.append(field)
-
-
 def is_coroutine_callable(call: Callable) -> bool:
     if inspect.isfunction(call):
         return asyncio.iscoroutinefunction(call)
@@ -354,7 +349,7 @@ def request_params_to_args(
         if field.shape in sequence_shapes and isinstance(
             received_params, (QueryParams, Headers)
         ):
-            value = received_params.getlist(field.alias)
+            value = received_params.getlist(field.alias) or field.default
         else:
             value = received_params.get(field.alias)
         schema: params.Param = field.schema
index 87e223cb6adb3995df528f0fd6ba90e7b55c222f..26d491beaeedb6eeeb45f83ec9bb700608d64771 100644 (file)
@@ -1,4 +1,4 @@
-from typing import Any, Dict, List, Optional, Sequence, Tuple, Type
+from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, cast
 
 from fastapi import routing
 from fastapi.dependencies.models import Dependant
@@ -9,7 +9,7 @@ from fastapi.openapi.models import OpenAPI
 from fastapi.params import Body, Param
 from fastapi.utils import get_flat_models_from_routes, get_model_definitions
 from pydantic.fields import Field
-from pydantic.schema import Schema, field_schema, get_model_name_map
+from pydantic.schema import field_schema, get_model_name_map
 from pydantic.utils import lenient_issubclass
 from starlette.responses import JSONResponse
 from starlette.routing import BaseRoute
@@ -97,12 +97,8 @@ def get_openapi_operation_request_body(
     body_schema, _ = field_schema(
         body_field, model_name_map=model_name_map, ref_prefix=REF_PREFIX
     )
-    schema: Schema = body_field.schema
-    if isinstance(schema, Body):
-        request_media_type = schema.media_type
-    else:
-        # Includes not declared media types (Schema)
-        request_media_type = "application/json"
+    body_field.schema = cast(Body, body_field.schema)
+    request_media_type = body_field.schema.media_type
     required = body_field.required
     request_body_oai: Dict[str, Any] = {}
     if required:
index e8be418388220e2ffe04c0969ca5ee6a618649ec..8700b0ef33062dcc84c14a807306ed4791bebd52 100644 (file)
@@ -20,7 +20,7 @@ classifiers = [
 ]
 requires = [
     "starlette >=0.11.1,<=0.12.0",
-    "pydantic >=0.17,<=0.26.0"
+    "pydantic >=0.26,<=0.26.0"
 ]
 description-file = "README.md"
 requires-python = ">=3.6"
diff --git a/tests/test_invalid_sequence_param.py b/tests/test_invalid_sequence_param.py
new file mode 100644 (file)
index 0000000..bdc4b1b
--- /dev/null
@@ -0,0 +1,29 @@
+from typing import List, Tuple
+
+import pytest
+from fastapi import FastAPI, Query
+from pydantic import BaseModel
+
+
+def test_invalid_sequence():
+    with pytest.raises(AssertionError):
+        app = FastAPI()
+
+        class Item(BaseModel):
+            title: str
+
+        @app.get("/items/")
+        def read_items(q: List[Item] = Query(None)):
+            pass  # pragma: no cover
+
+
+def test_invalid_tuple():
+    with pytest.raises(AssertionError):
+        app = FastAPI()
+
+        class Item(BaseModel):
+            title: str
+
+        @app.get("/items/")
+        def read_items(q: Tuple[Item, Item] = Query(None)):
+            pass  # pragma: no cover
diff --git a/tests/test_tutorial/test_path_params/test_tutorial005.py b/tests/test_tutorial/test_path_params/test_tutorial005.py
new file mode 100644 (file)
index 0000000..3245cdc
--- /dev/null
@@ -0,0 +1,120 @@
+import pytest
+from starlette.testclient import TestClient
+
+from path_params.tutorial005 import app
+
+client = TestClient(app)
+
+openapi_schema = {
+    "openapi": "3.0.2",
+    "info": {"title": "Fast API", "version": "0.1.0"},
+    "paths": {
+        "/model/{model_name}": {
+            "get": {
+                "responses": {
+                    "200": {
+                        "description": "Successful Response",
+                        "content": {"application/json": {"schema": {}}},
+                    },
+                    "422": {
+                        "description": "Validation Error",
+                        "content": {
+                            "application/json": {
+                                "schema": {
+                                    "$ref": "#/components/schemas/HTTPValidationError"
+                                }
+                            }
+                        },
+                    },
+                },
+                "summary": "Get Model",
+                "operationId": "get_model_model__model_name__get",
+                "parameters": [
+                    {
+                        "required": True,
+                        "schema": {
+                            "title": "Model_Name",
+                            "enum": ["alexnet", "resnet", "lenet"],
+                        },
+                        "name": "model_name",
+                        "in": "path",
+                    }
+                ],
+            }
+        }
+    },
+    "components": {
+        "schemas": {
+            "ValidationError": {
+                "title": "ValidationError",
+                "required": ["loc", "msg", "type"],
+                "type": "object",
+                "properties": {
+                    "loc": {
+                        "title": "Location",
+                        "type": "array",
+                        "items": {"type": "string"},
+                    },
+                    "msg": {"title": "Message", "type": "string"},
+                    "type": {"title": "Error Type", "type": "string"},
+                },
+            },
+            "HTTPValidationError": {
+                "title": "HTTPValidationError",
+                "type": "object",
+                "properties": {
+                    "detail": {
+                        "title": "Detail",
+                        "type": "array",
+                        "items": {"$ref": "#/components/schemas/ValidationError"},
+                    }
+                },
+            },
+        }
+    },
+}
+
+
+def test_openapi():
+    response = client.get("/openapi.json")
+    assert response.status_code == 200
+    assert response.json() == openapi_schema
+
+
+@pytest.mark.parametrize(
+    "url,status_code,expected",
+    [
+        (
+            "/model/alexnet",
+            200,
+            {"model_name": "alexnet", "message": "Deep Learning FTW!"},
+        ),
+        (
+            "/model/lenet",
+            200,
+            {"model_name": "lenet", "message": "LeCNN all the images"},
+        ),
+        (
+            "/model/resnet",
+            200,
+            {"model_name": "resnet", "message": "Have some residuals"},
+        ),
+        (
+            "/model/foo",
+            422,
+            {
+                "detail": [
+                    {
+                        "loc": ["path", "model_name"],
+                        "msg": "value is not a valid enumeration member",
+                        "type": "type_error.enum",
+                    }
+                ]
+            },
+        ),
+    ],
+)
+def test_get_enums(url, status_code, expected):
+    response = client.get(url)
+    assert response.status_code == status_code
+    assert response.json() == expected
diff --git a/tests/test_tutorial/test_query_params/test_tutorial007.py b/tests/test_tutorial/test_query_params/test_tutorial007.py
new file mode 100644 (file)
index 0000000..a0fb238
--- /dev/null
@@ -0,0 +1,95 @@
+from starlette.testclient import TestClient
+
+from query_params.tutorial007 import app
+
+client = TestClient(app)
+
+openapi_schema = {
+    "openapi": "3.0.2",
+    "info": {"title": "Fast API", "version": "0.1.0"},
+    "paths": {
+        "/items/{item_id}": {
+            "get": {
+                "responses": {
+                    "200": {
+                        "description": "Successful Response",
+                        "content": {"application/json": {"schema": {}}},
+                    },
+                    "422": {
+                        "description": "Validation Error",
+                        "content": {
+                            "application/json": {
+                                "schema": {
+                                    "$ref": "#/components/schemas/HTTPValidationError"
+                                }
+                            }
+                        },
+                    },
+                },
+                "summary": "Read User Item",
+                "operationId": "read_user_item_items__item_id__get",
+                "parameters": [
+                    {
+                        "required": True,
+                        "schema": {"title": "Item_Id", "type": "string"},
+                        "name": "item_id",
+                        "in": "path",
+                    },
+                    {
+                        "required": False,
+                        "schema": {"title": "Limit", "type": "integer"},
+                        "name": "limit",
+                        "in": "query",
+                    },
+                ],
+            }
+        }
+    },
+    "components": {
+        "schemas": {
+            "ValidationError": {
+                "title": "ValidationError",
+                "required": ["loc", "msg", "type"],
+                "type": "object",
+                "properties": {
+                    "loc": {
+                        "title": "Location",
+                        "type": "array",
+                        "items": {"type": "string"},
+                    },
+                    "msg": {"title": "Message", "type": "string"},
+                    "type": {"title": "Error Type", "type": "string"},
+                },
+            },
+            "HTTPValidationError": {
+                "title": "HTTPValidationError",
+                "type": "object",
+                "properties": {
+                    "detail": {
+                        "title": "Detail",
+                        "type": "array",
+                        "items": {"$ref": "#/components/schemas/ValidationError"},
+                    }
+                },
+            },
+        }
+    },
+}
+
+
+def test_openapi():
+    response = client.get("/openapi.json")
+    assert response.status_code == 200
+    assert response.json() == openapi_schema
+
+
+def test_read_item():
+    response = client.get("/items/foo")
+    assert response.status_code == 200
+    assert response.json() == {"item_id": "foo", "limit": None}
+
+
+def test_read_item_query():
+    response = client.get("/items/foo?limit=5")
+    assert response.status_code == 200
+    assert response.json() == {"item_id": "foo", "limit": 5}
diff --git a/tests/test_tutorial/test_query_params_str_validations/test_tutorial012.py b/tests/test_tutorial/test_query_params_str_validations/test_tutorial012.py
new file mode 100644 (file)
index 0000000..1e00c50
--- /dev/null
@@ -0,0 +1,96 @@
+from starlette.testclient import TestClient
+
+from query_params_str_validations.tutorial012 import app
+
+client = TestClient(app)
+
+openapi_schema = {
+    "openapi": "3.0.2",
+    "info": {"title": "Fast API", "version": "0.1.0"},
+    "paths": {
+        "/items/": {
+            "get": {
+                "responses": {
+                    "200": {
+                        "description": "Successful Response",
+                        "content": {"application/json": {"schema": {}}},
+                    },
+                    "422": {
+                        "description": "Validation Error",
+                        "content": {
+                            "application/json": {
+                                "schema": {
+                                    "$ref": "#/components/schemas/HTTPValidationError"
+                                }
+                            }
+                        },
+                    },
+                },
+                "summary": "Read Items",
+                "operationId": "read_items_items__get",
+                "parameters": [
+                    {
+                        "required": False,
+                        "schema": {
+                            "title": "Q",
+                            "type": "array",
+                            "items": {"type": "string"},
+                            "default": ["foo", "bar"],
+                        },
+                        "name": "q",
+                        "in": "query",
+                    }
+                ],
+            }
+        }
+    },
+    "components": {
+        "schemas": {
+            "ValidationError": {
+                "title": "ValidationError",
+                "required": ["loc", "msg", "type"],
+                "type": "object",
+                "properties": {
+                    "loc": {
+                        "title": "Location",
+                        "type": "array",
+                        "items": {"type": "string"},
+                    },
+                    "msg": {"title": "Message", "type": "string"},
+                    "type": {"title": "Error Type", "type": "string"},
+                },
+            },
+            "HTTPValidationError": {
+                "title": "HTTPValidationError",
+                "type": "object",
+                "properties": {
+                    "detail": {
+                        "title": "Detail",
+                        "type": "array",
+                        "items": {"$ref": "#/components/schemas/ValidationError"},
+                    }
+                },
+            },
+        }
+    },
+}
+
+
+def test_openapi_schema():
+    response = client.get("/openapi.json")
+    assert response.status_code == 200
+    assert response.json() == openapi_schema
+
+
+def test_default_query_values():
+    url = "/items/"
+    response = client.get(url)
+    assert response.status_code == 200
+    assert response.json() == {"q": ["foo", "bar"]}
+
+
+def test_multi_query_values():
+    url = "/items/?q=baz&q=foobar"
+    response = client.get(url)
+    assert response.status_code == 200
+    assert response.json() == {"q": ["baz", "foobar"]}