]> git.ipfire.org Git - thirdparty/fastapi/fastapi.git/commitdiff
:white_check_mark: Add first tests, for path and query
authorSebastián Ramírez <tiangolo@gmail.com>
Sat, 8 Dec 2018 07:56:07 +0000 (11:56 +0400)
committerSebastián Ramírez <tiangolo@gmail.com>
Sat, 8 Dec 2018 07:56:07 +0000 (11:56 +0400)
tests/__init__.py [new file with mode: 0644]
tests/endpoints/__init__.py [new file with mode: 0644]
tests/endpoints/a.py [new file with mode: 0644]
tests/endpoints/b.py [new file with mode: 0644]
tests/main.py [new file with mode: 0644]
tests/test_path.py [new file with mode: 0644]
tests/test_query.py [new file with mode: 0644]

diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644 (file)
index 0000000..e69de29
diff --git a/tests/endpoints/__init__.py b/tests/endpoints/__init__.py
new file mode 100644 (file)
index 0000000..e69de29
diff --git a/tests/endpoints/a.py b/tests/endpoints/a.py
new file mode 100644 (file)
index 0000000..be5663f
--- /dev/null
@@ -0,0 +1,13 @@
+from fastapi.routing import APIRouter
+
+router = APIRouter()
+
+
+@router.get("/dog")
+def get_a_dog():
+    return "Woof"
+
+
+@router.get("/cat")
+def get_a_cat():
+    return "Meow"
diff --git a/tests/endpoints/b.py b/tests/endpoints/b.py
new file mode 100644 (file)
index 0000000..7747fb2
--- /dev/null
@@ -0,0 +1,13 @@
+from fastapi.routing import APIRouter
+
+router = APIRouter()
+
+
+@router.get("/dog")
+def get_b_dog():
+    return "B Woof"
+
+
+@router.get("/cat")
+def get_b_cat():
+    return "B Meow"
diff --git a/tests/main.py b/tests/main.py
new file mode 100644 (file)
index 0000000..4d2b199
--- /dev/null
@@ -0,0 +1,350 @@
+from fastapi.applications import FastAPI
+from fastapi.params import (
+    Body,
+    Cookie,
+    Depends,
+    File,
+    Form,
+    Header,
+    Param,
+    Path,
+    Query,
+    Security,
+)
+from fastapi.security.http import HTTPBasic
+from fastapi.security.oauth2 import (
+    OAuth2,
+    OAuth2PasswordRequestData,
+    OAuth2PasswordRequestForm,
+)
+from pydantic import BaseModel
+from starlette.responses import HTMLResponse, JSONResponse, PlainTextResponse
+from starlette.status import HTTP_202_ACCEPTED
+from starlette.testclient import TestClient
+
+from .endpoints.a import router as router_a
+from .endpoints.b import router as router_b
+
+app = FastAPI()
+
+
+app.include_router(router_a)
+app.include_router(router_b, prefix="/b")
+
+
+@app.get("/text")
+def get_text():
+    return "Hello World"
+
+
+@app.get("/path/{item_id}")
+def get_id(item_id):
+    return item_id
+
+
+@app.get("/path/str/{item_id}")
+def get_str_id(item_id: str):
+    return item_id
+
+
+@app.get("/path/int/{item_id}")
+def get_int_id(item_id: int):
+    return item_id
+
+
+@app.get("/path/float/{item_id}")
+def get_float_id(item_id: float):
+    return item_id
+
+
+@app.get("/path/bool/{item_id}")
+def get_bool_id(item_id: bool):
+    return item_id
+
+
+@app.get("/path/param/{item_id}")
+def get_path_param_id(item_id: str = Path(None)):
+    return item_id
+
+
+@app.get("/path/param-required/{item_id}")
+def get_path_param_required_id(item_id: str = Path(...)):
+    return item_id
+
+
+@app.get("/query")
+def get_query(query):
+    if query is None:
+        return "foo bar"
+    return f"foo bar {query}"
+
+
+@app.get("/query/optional")
+def get_query_optional(query=None):
+    if query is None:
+        return "foo bar"
+    return f"foo bar {query}"
+
+
+@app.get("/query/int")
+def get_query_type(query: int):
+    if query is None:
+        return "foo bar"
+    return f"foo bar {query}"
+
+
+@app.get("/query/int/optional")
+def get_query_type_optional(query: int = None):
+    if query is None:
+        return "foo bar"
+    return f"foo bar {query}"
+
+
+@app.get("/query/param")
+def get_query_param(query=Query(None)):
+    if query is None:
+        return "foo bar"
+    return f"foo bar {query}"
+
+
+@app.get("/query/param-required")
+def get_query_param_required(query=Query(...)):
+    if query is None:
+        return "foo bar"
+    return f"foo bar {query}"
+
+
+@app.get("/query/param-required/int")
+def get_query_param_required_type(query: int = Query(...)):
+    if query is None:
+        return "foo bar"
+    return f"foo bar {query}"
+
+
+@app.get("/cookie")
+def get_cookie(coo=Cookie(None)):
+    return coo
+
+
+@app.get("/header")
+def get_header(head_name=Header(None)):
+    return head_name
+
+
+@app.get("/header_under")
+def get_header(head_name=Header(None, convert_underscores=False)):
+    return head_name
+
+
+@app.get("/param")
+def get_param(par=Param(None)):
+    return par
+
+
+@app.get("/security")
+def get_security(sec=Security(HTTPBasic())):
+    return sec
+
+
+reusable_oauth2 = OAuth2(
+    flows={
+        "password": {
+            "tokenUrl": "/token",
+            "scopes": {"read:user": "Read a User", "write:user": "Create a user"},
+        }
+    }
+)
+
+
+@app.get("/security/oauth2")
+def get_security_oauth2(sec=Security(reusable_oauth2, scopes=["read:user"])):
+    return sec
+
+
+@app.post("/token")
+def post_token(request_data: OAuth2PasswordRequestForm = Form(...)):
+    print(request_data)
+    data = request_data.parse()
+    print(data)
+
+    print(request_data())
+    access_token = request_data.username + ":" + request_data.password
+    return {"access_token": access_token}
+
+
+class Item(BaseModel):
+    name: str
+    price: float
+    is_offer: bool
+
+
+@app.put("/items/{item_id}")
+def put_item(item_id: str, item: Item):
+    return item
+
+
+@app.post("/items/")
+def post_item(item: Item):
+    return item
+
+
+@app.post("/items-all-params/{item_id}")
+def post_items_all_params(
+    item_id: str = Path(...),
+    body: Item = Body(...),
+    query_a: int = Query(None),
+    query_b=Query(None),
+    coo: str = Cookie(None),
+    x_head: int = Header(None),
+    x_under: str = Header(None, convert_underscores=False),
+):
+    return {
+        "item_id": item_id,
+        "body": body,
+        "query_a": query_a,
+        "query_b": query_b,
+        "coo": coo,
+        "x_head": x_head,
+        "x_under": x_under,
+    }
+
+
+@app.post("/items-all-params-defaults/{item_id}")
+def post_items_all_params_default(
+    item_id: str,
+    body_item_a: Item,
+    body_item_b: Item,
+    query_a: int,
+    query_b: int,
+    coo: str = Cookie(None),
+    x_head: int = Header(None),
+    x_under: str = Header(None, convert_underscores=False),
+):
+    return {
+        "item_id": item_id,
+        "body_item_a": body_item_a,
+        "body_item_b": body_item_b,
+        "query_a": query_a,
+        "query_b": query_b,
+        "coo": coo,
+        "x_head": x_head,
+        "x_under": x_under,
+    }
+
+
+@app.delete("/items/{item_id}")
+def delete_item(item_id: str):
+    return item_id
+
+
+@app.options("/options/")
+def options():
+    return JSONResponse(headers={"x-fastapi": "fast"})
+
+
+@app.head("/head/")
+def head():
+    return {"not sent": "nope"}
+
+
+@app.patch("/patch/{user_id}")
+def patch(user_id: str, increment: float):
+    return {"user_id": user_id, "total": 5 + increment}
+
+
+@app.trace("/trace/")
+def trace():
+    return PlainTextResponse(media_type="message/http")
+
+
+@app.get("/model", response_model=Item, status_code=HTTP_202_ACCEPTED)
+def model():
+    return {"name": "Foo", "price": "5.0", "password": "not sent"}
+
+
+@app.get(
+    "/metadata",
+    tags=["tag1", "tag2"],
+    summary="The summary",
+    description="The description",
+    response_description="Response description",
+    deprecated=True,
+    operation_id="a_very_long_and_strange_operation_id",
+)
+def get_meta():
+    return "Foo"
+
+
+@app.get("/html", content_type=HTMLResponse)
+def get_html():
+    return """
+    <html>
+    <body>
+    <h1>
+    Some text inside
+    </h1>
+    </body>
+    </html>
+    """
+
+
+class FakeDB:
+    def __init__(self):
+        self.data = {
+            "johndoe": {
+                "username": "johndoe",
+                "password": "shouldbehashed",
+                "fist_name": "John",
+                "last_name": "Doe",
+            }
+        }
+
+
+class DBConnectionManager:
+    def __init__(self):
+        self.db = FakeDB()
+
+    def __call__(self):
+        return self.db
+
+
+connection_manager = DBConnectionManager()
+
+
+class TokenUserData(BaseModel):
+    username: str
+    password: str
+
+
+class UserInDB(BaseModel):
+    username: str
+    password: str
+    fist_name: str
+    last_name: str
+
+
+def require_token(
+    token: str = Security(reusable_oauth2, scopes=["read:user", "write:user"])
+):
+    raw_token = token.replace("Bearer ", "")
+    # Never do this plaintext password usage in production
+    username, password = raw_token.split(":")
+    return TokenUserData(username=username, password=password)
+
+
+def require_user(
+    db: FakeDB = Depends(connection_manager),
+    user_data: TokenUserData = Depends(require_token),
+):
+    return db.data[user_data.username]
+
+
+class UserOut(BaseModel):
+    username: str
+    fist_name: str
+    last_name: str
+
+
+@app.get("/dependency", response_model=UserOut)
+def get_dependency(user: UserInDB = Depends(require_user)):
+    return user
diff --git a/tests/test_path.py b/tests/test_path.py
new file mode 100644 (file)
index 0000000..f271331
--- /dev/null
@@ -0,0 +1,73 @@
+import pytest
+from starlette.testclient import TestClient
+
+from .main import app
+
+client = TestClient(app)
+
+
+def test_text_get():
+    response = client.get("/text")
+    assert response.status_code == 200
+    assert response.json() == "Hello World"
+
+
+def test_nonexistent():
+    response = client.get("/nonexistent")
+    assert response.status_code == 404
+    assert response.json() == {"detail": "Not Found"}
+
+
+response_not_valid_int = {
+    "detail": [
+        {
+            "loc": ["path", "item_id"],
+            "msg": "value is not a valid integer",
+            "type": "type_error.integer",
+        }
+    ]
+}
+
+response_not_valid_float = {
+    "detail": [
+        {
+            "loc": ["path", "item_id"],
+            "msg": "value is not a valid float",
+            "type": "type_error.float",
+        }
+    ]
+}
+
+
+@pytest.mark.parametrize(
+    "path,expected_status,expected_response",
+    [
+        ("/path/foobar", 200, "foobar"),
+        ("/path/str/foobar", 200, "foobar"),
+        ("/path/str/42", 200, "42"),
+        ("/path/str/True", 200, "True"),
+        ("/path/int/foobar", 422, response_not_valid_int),
+        ("/path/int/True", 422, response_not_valid_int),
+        ("/path/int/42", 200, 42),
+        ("/path/int/42.5", 422, response_not_valid_int),
+        ("/path/float/foobar", 422, response_not_valid_float),
+        ("/path/float/True", 422, response_not_valid_float),
+        ("/path/float/42", 200, 42),
+        ("/path/float/42.5", 200, 42.5),
+        ("/path/bool/foobar", 200, False),
+        ("/path/bool/True", 200, True),
+        ("/path/bool/42", 200, False),
+        ("/path/bool/42.5", 200, False),
+        ("/path/bool/1", 200, True),
+        ("/path/bool/0", 200, False),
+        ("/path/bool/true", 200, True),
+        ("/path/bool/False", 200, False),
+        ("/path/bool/false", 200, False),
+        ("/path/param/foo", 200, "foo"),
+        ("/path/param-required/foo", 200, "foo"),
+    ],
+)
+def test_get_path(path, expected_status, expected_response):
+    response = client.get(path)
+    assert response.status_code == expected_status
+    assert response.json() == expected_response
diff --git a/tests/test_query.py b/tests/test_query.py
new file mode 100644 (file)
index 0000000..fc792b8
--- /dev/null
@@ -0,0 +1,44 @@
+import pytest
+from starlette.testclient import TestClient
+
+from .main import app
+
+client = TestClient(app)
+
+response_missing = {
+    "detail": [
+        {"loc": ["query"], "msg": "field required", "type": "value_error.missing"}
+    ]
+}
+
+response_not_valid_int = {
+    "detail": [
+        {
+            "loc": ["query", "query"],
+            "msg": "value is not a valid integer",
+            "type": "type_error.integer",
+        }
+    ]
+}
+
+
+@pytest.mark.parametrize(
+    "path,expected_status,expected_response",
+    [
+        ("/query", 422, response_missing),
+        ("/query?query=baz", 200, "foo bar baz"),
+        ("/query?not_declared=baz", 422, response_missing),
+        ("/query/optional", 200, "foo bar"),
+        ("/query/optional?query=baz", 200, "foo bar baz"),
+        ("/query/optional?not_declared=baz", 200, "foo bar"),
+        ("/query/int", 422, response_missing),
+        ("/query/int?query=42", 200, "foo bar 42"),
+        ("/query/int?query=42.5", 422, response_not_valid_int),
+        ("/query/int?query=baz", 422, response_not_valid_int),
+        ("/query/int?not_declared=baz", 422, response_missing),
+    ],
+)
+def test_get_path(path, expected_status, expected_response):
+    response = client.get(path)
+    assert response.status_code == expected_status
+    assert response.json() == expected_response