]> git.ipfire.org Git - thirdparty/fastapi/fastapi.git/commitdiff
:bug: Fix handling additional responses in include_router (#140)
authorSebastián Ramírez <tiangolo@gmail.com>
Fri, 5 Apr 2019 16:06:40 +0000 (20:06 +0400)
committerGitHub <noreply@github.com>
Fri, 5 Apr 2019 16:06:40 +0000 (20:06 +0400)
fastapi/routing.py
tests/test_additional_responses_router.py [new file with mode: 0644]

index e768c3ad3d1cb667fa2094ff0d0ead01a1a32a7e..2bdf46ddc3c7dde48b2841c522de76d4abe89700 100644 (file)
@@ -285,11 +285,11 @@ class APIRouter(routing.Router):
             assert not prefix.endswith(
                 "/"
             ), "A path prefix must not end with '/', as the routes will start with '/'"
+        if responses is None:
+            responses = {}
         for route in router.routes:
             if isinstance(route, APIRoute):
-                if responses is None:
-                    responses = {}
-                responses = {**responses, **route.responses}
+                combined_responses = {**responses, **route.responses}
                 self.add_api_route(
                     prefix + route.path,
                     route.endpoint,
@@ -299,7 +299,7 @@ class APIRouter(routing.Router):
                     summary=route.summary,
                     description=route.description,
                     response_description=route.response_description,
-                    responses=responses,
+                    responses=combined_responses,
                     deprecated=route.deprecated,
                     methods=route.methods,
                     operation_id=route.operation_id,
diff --git a/tests/test_additional_responses_router.py b/tests/test_additional_responses_router.py
new file mode 100644 (file)
index 0000000..e97bb62
--- /dev/null
@@ -0,0 +1,95 @@
+from fastapi import APIRouter, FastAPI
+from starlette.testclient import TestClient
+
+app = FastAPI()
+router = APIRouter()
+
+
+@router.get("/a", responses={501: {"description": "Error 1"}})
+async def a():
+    return "a"
+
+
+@router.get("/b", responses={502: {"description": "Error 2"}})
+async def b():
+    return "b"
+
+
+@router.get("/c", responses={501: {"description": "Error 3"}})
+async def c():
+    return "c"
+
+
+app.include_router(router)
+
+openapi_schema = {
+    "openapi": "3.0.2",
+    "info": {"title": "Fast API", "version": "0.1.0"},
+    "paths": {
+        "/a": {
+            "get": {
+                "responses": {
+                    "501": {"description": "Error 1"},
+                    "200": {
+                        "description": "Successful Response",
+                        "content": {"application/json": {"schema": {}}},
+                    },
+                },
+                "summary": "A Get",
+                "operationId": "a_a_get",
+            }
+        },
+        "/b": {
+            "get": {
+                "responses": {
+                    "502": {"description": "Error 2"},
+                    "200": {
+                        "description": "Successful Response",
+                        "content": {"application/json": {"schema": {}}},
+                    },
+                },
+                "summary": "B Get",
+                "operationId": "b_b_get",
+            }
+        },
+        "/c": {
+            "get": {
+                "responses": {
+                    "501": {"description": "Error 3"},
+                    "200": {
+                        "description": "Successful Response",
+                        "content": {"application/json": {"schema": {}}},
+                    },
+                },
+                "summary": "C Get",
+                "operationId": "c_c_get",
+            }
+        },
+    },
+}
+
+client = TestClient(app)
+
+
+def test_openapi_schema():
+    response = client.get("/openapi.json")
+    assert response.status_code == 200
+    assert response.json() == openapi_schema
+
+
+def test_a():
+    response = client.get("/a")
+    assert response.status_code == 200
+    assert response.json() == "a"
+
+
+def test_b():
+    response = client.get("/b")
+    assert response.status_code == 200
+    assert response.json() == "b"
+
+
+def test_c():
+    response = client.get("/c")
+    assert response.status_code == 200
+    assert response.json() == "c"