]> git.ipfire.org Git - thirdparty/fastapi/fastapi.git/commitdiff
:sparkles: Allow defaults in path parameters (and don't use them) (#450) (#464)
authorjonathanunderwood <jonathan.underwood@gmail.com>
Sun, 29 Sep 2019 22:03:16 +0000 (23:03 +0100)
committerSebastián Ramírez <tiangolo@gmail.com>
Sun, 29 Sep 2019 22:03:16 +0000 (17:03 -0500)
This allows using parameters that can have defaults (e.g. `None`) that can be used as query parameters.

But can also be used in routers with that include those parameters as part of the path.

fastapi/dependencies/utils.py
tests/test_infer_param_optionality.py [new file with mode: 0644]

index 852f1e0253dee82483ab157c76c676cd201c7d99..d5d1145653d1ad092261f7a512d67d447ae535df 100644 (file)
@@ -220,16 +220,18 @@ def get_dependant(
             continue
         param_field = get_param_field(param=param, default_schema=params.Query)
         if param_name in path_param_names:
-            assert param.default == param.empty or isinstance(
-                param.default, params.Path
-            ), "Path params must have no defaults or use Path(...)"
             assert is_scalar_field(
                 field=param_field
             ), f"Path params must be of one of the supported types"
+            if isinstance(param.default, params.Path):
+                ignore_default = False
+            else:
+                ignore_default = True
             param_field = get_param_field(
                 param=param,
                 default_schema=params.Path,
                 force_type=params.ParamTypes.path,
+                ignore_default=ignore_default,
             )
             add_param_to_fields(field=param_field, dependant=dependant)
         elif is_scalar_field(field=param_field):
@@ -272,10 +274,11 @@ def get_param_field(
     param: inspect.Parameter,
     default_schema: Type[params.Param] = params.Param,
     force_type: params.ParamTypes = None,
+    ignore_default: bool = False,
 ) -> Field:
     default_value = Required
     had_schema = False
-    if not param.default == param.empty:
+    if not param.default == param.empty and ignore_default is False:
         default_value = param.default
     if isinstance(default_value, Schema):
         had_schema = True
diff --git a/tests/test_infer_param_optionality.py b/tests/test_infer_param_optionality.py
new file mode 100644 (file)
index 0000000..79fa716
--- /dev/null
@@ -0,0 +1,136 @@
+from fastapi import APIRouter, FastAPI
+from starlette.testclient import TestClient
+
+app = FastAPI()
+
+
+user_router = APIRouter()
+item_router = APIRouter()
+
+
+@user_router.get("/")
+def get_users():
+    return [{"user_id": "u1"}, {"user_id": "u2"}]
+
+
+@user_router.get("/{user_id}")
+def get_user(user_id: str):
+    return {"user_id": user_id}
+
+
+@item_router.get("/")
+def get_items(user_id: str = None):
+    if user_id is None:
+        return [{"item_id": "i1", "user_id": "u1"}, {"item_id": "i2", "user_id": "u2"}]
+    else:
+        return [{"item_id": "i2", "user_id": user_id}]
+
+
+@item_router.get("/{item_id}")
+def get_item(item_id: str, user_id: str = None):
+    if user_id is None:
+        return {"item_id": item_id}
+    else:
+        return {"item_id": item_id, "user_id": user_id}
+
+
+app.include_router(user_router, prefix="/users")
+app.include_router(item_router, prefix="/items")
+
+app.include_router(item_router, prefix="/users/{user_id}/items")
+
+
+client = TestClient(app)
+
+
+def test_get_users():
+    """Check that /users returns expected data"""
+    response = client.get("/users")
+    assert response.status_code == 200
+    assert response.json() == [{"user_id": "u1"}, {"user_id": "u2"}]
+
+
+def test_get_user():
+    """Check that /users/{user_id} returns expected data"""
+    response = client.get("/users/abc123")
+    assert response.status_code == 200
+    assert response.json() == {"user_id": "abc123"}
+
+
+def test_get_items_1():
+    """Check that /items returns expected data"""
+    response = client.get("/items")
+    assert response.status_code == 200
+    assert response.json() == [
+        {"item_id": "i1", "user_id": "u1"},
+        {"item_id": "i2", "user_id": "u2"},
+    ]
+
+
+def test_get_items_2():
+    """Check that /items returns expected data with user_id specified"""
+    response = client.get("/items?user_id=abc123")
+    assert response.status_code == 200
+    assert response.json() == [{"item_id": "i2", "user_id": "abc123"}]
+
+
+def test_get_item_1():
+    """Check that /items/{item_id} returns expected data"""
+    response = client.get("/items/item01")
+    assert response.status_code == 200
+    assert response.json() == {"item_id": "item01"}
+
+
+def test_get_item_2():
+    """Check that /items/{item_id} returns expected data with user_id specified"""
+    response = client.get("/items/item01?user_id=abc123")
+    assert response.status_code == 200
+    assert response.json() == {"item_id": "item01", "user_id": "abc123"}
+
+
+def test_get_users_items():
+    """Check that /users/{user_id}/items returns expected data"""
+    response = client.get("/users/abc123/items")
+    assert response.status_code == 200
+    assert response.json() == [{"item_id": "i2", "user_id": "abc123"}]
+
+
+def test_get_users_item():
+    """Check that /users/{user_id}/items returns expected data"""
+    response = client.get("/users/abc123/items/item01")
+    assert response.status_code == 200
+    assert response.json() == {"item_id": "item01", "user_id": "abc123"}
+
+
+def test_schema_1():
+    """Check that the user_id is a required path parameter under /users"""
+    response = client.get("/openapi.json")
+    assert response.status_code == 200
+    r = response.json()
+
+    d = {
+        "required": True,
+        "schema": {"title": "User_Id", "type": "string"},
+        "name": "user_id",
+        "in": "path",
+    }
+
+    assert d in r["paths"]["/users/{user_id}"]["get"]["parameters"]
+    assert d in r["paths"]["/users/{user_id}/items/"]["get"]["parameters"]
+
+
+def test_schema_2():
+    """Check that the user_id is an optional query parameter under /items"""
+    response = client.get("/openapi.json")
+    assert response.status_code == 200
+    r = response.json()
+
+    d = {
+        "required": False,
+        "schema": {"title": "User_Id", "type": "string"},
+        "name": "user_id",
+        "in": "query",
+    }
+
+    assert d in r["paths"]["/items/{item_id}"]["get"]["parameters"]
+    assert d in r["paths"]["/items/"]["get"]["parameters"]