]> git.ipfire.org Git - thirdparty/fastapi/fastapi.git/commitdiff
:bug: Fix preserving route_class when calling include_router (#538)
authordmontagu <35119617+dmontagu@users.noreply.github.com>
Fri, 4 Oct 2019 21:35:20 +0000 (14:35 -0700)
committerSebastián Ramírez <tiangolo@gmail.com>
Fri, 4 Oct 2019 21:35:20 +0000 (16:35 -0500)
fastapi/routing.py
tests/test_custom_route_class.py [new file with mode: 0644]

index 8f61ea50ca484aa86662bc41756fd5a6c12cd0e5..b0902310c7c9dadabbcde3fb0e79ee4209e78a2d 100644 (file)
@@ -348,8 +348,10 @@ class APIRouter(routing.Router):
         include_in_schema: bool = True,
         response_class: Type[Response] = None,
         name: str = None,
+        route_class_override: Optional[Type[APIRoute]] = None,
     ) -> None:
-        route = self.route_class(
+        route_class = route_class_override or self.route_class
+        route = route_class(
             path,
             endpoint=endpoint,
             response_model=response_model,
@@ -487,6 +489,7 @@ class APIRouter(routing.Router):
                     include_in_schema=route.include_in_schema,
                     response_class=route.response_class or default_response_class,
                     name=route.name,
+                    route_class_override=type(route),
                 )
             elif isinstance(route, routing.Route):
                 self.add_route(
diff --git a/tests/test_custom_route_class.py b/tests/test_custom_route_class.py
new file mode 100644 (file)
index 0000000..8bbf88a
--- /dev/null
@@ -0,0 +1,114 @@
+import pytest
+from fastapi import APIRouter, FastAPI
+from fastapi.routing import APIRoute
+from starlette.testclient import TestClient
+
+app = FastAPI()
+
+
+class APIRouteA(APIRoute):
+    x_type = "A"
+
+
+class APIRouteB(APIRoute):
+    x_type = "B"
+
+
+class APIRouteC(APIRoute):
+    x_type = "C"
+
+
+router_a = APIRouter(route_class=APIRouteA)
+router_b = APIRouter(route_class=APIRouteB)
+router_c = APIRouter(route_class=APIRouteC)
+
+
+@router_a.get("/")
+def get_a():
+    return {"msg": "A"}
+
+
+@router_b.get("/")
+def get_b():
+    return {"msg": "B"}
+
+
+@router_c.get("/")
+def get_c():
+    return {"msg": "C"}
+
+
+router_b.include_router(router=router_c, prefix="/c")
+router_a.include_router(router=router_b, prefix="/b")
+app.include_router(router=router_a, prefix="/a")
+
+
+client = TestClient(app)
+
+openapi_schema = {
+    "openapi": "3.0.2",
+    "info": {"title": "Fast API", "version": "0.1.0"},
+    "paths": {
+        "/a/": {
+            "get": {
+                "responses": {
+                    "200": {
+                        "description": "Successful Response",
+                        "content": {"application/json": {"schema": {}}},
+                    }
+                },
+                "summary": "Get A",
+                "operationId": "get_a_a__get",
+            }
+        },
+        "/a/b/": {
+            "get": {
+                "responses": {
+                    "200": {
+                        "description": "Successful Response",
+                        "content": {"application/json": {"schema": {}}},
+                    }
+                },
+                "summary": "Get B",
+                "operationId": "get_b_a_b__get",
+            }
+        },
+        "/a/b/c/": {
+            "get": {
+                "responses": {
+                    "200": {
+                        "description": "Successful Response",
+                        "content": {"application/json": {"schema": {}}},
+                    }
+                },
+                "summary": "Get C",
+                "operationId": "get_c_a_b_c__get",
+            }
+        },
+    },
+}
+
+
+@pytest.mark.parametrize(
+    "path,expected_status,expected_response",
+    [
+        ("/a", 200, {"msg": "A"}),
+        ("/a/b", 200, {"msg": "B"}),
+        ("/a/b/c", 200, {"msg": "C"}),
+        ("/openapi.json", 200, openapi_schema),
+    ],
+)
+def test_get_path(path, expected_status, expected_response):
+    response = client.get(path)
+    assert response.status_code == expected_status
+    assert response.json() == expected_response
+
+
+def test_route_classes():
+    routes = {}
+    r: APIRoute
+    for r in app.router.routes:
+        routes[r.path] = r
+    assert routes["/a/"].x_type == "A"
+    assert routes["/a/b/"].x_type == "B"
+    assert routes["/a/b/c/"].x_type == "C"