def setup(self) -> None:
if self.openapi_url:
- urls = (server_data.get("url") for server_data in self.servers)
- server_urls = {url for url in urls if url}
async def openapi(req: Request) -> JSONResponse:
root_path = req.scope.get("root_path", "").rstrip("/")
- if root_path not in server_urls:
- if root_path and self.root_path_in_servers:
- self.servers.insert(0, {"url": root_path})
- server_urls.add(root_path)
- return JSONResponse(self.openapi())
+ schema = self.openapi()
+ if root_path and self.root_path_in_servers:
+ server_urls = {s.get("url") for s in schema.get("servers", [])}
+ if root_path not in server_urls:
+ schema = dict(schema)
+ schema["servers"] = [{"url": root_path}] + schema.get(
+ "servers", []
+ )
+ return JSONResponse(schema)
self.add_route(self.openapi_url, openapi, include_in_schema=False)
if self.openapi_url and self.docs_url:
from fastapi.encoders import jsonable_encoder
from starlette.responses import HTMLResponse
+
+def _html_safe_json(value: Any) -> str:
+ """Serialize a value to JSON with HTML special characters escaped.
+
+ This prevents injection when the JSON is embedded inside a <script> tag.
+ """
+ return (
+ json.dumps(value)
+ .replace("<", "\\u003c")
+ .replace(">", "\\u003e")
+ .replace("&", "\\u0026")
+ )
+
+
swagger_ui_default_parameters: Annotated[
dict[str, Any],
Doc(
"""
for key, value in current_swagger_ui_parameters.items():
- html += f"{json.dumps(key)}: {json.dumps(jsonable_encoder(value))},\n"
+ html += f"{_html_safe_json(key)}: {_html_safe_json(jsonable_encoder(value))},\n"
if oauth2_redirect_url:
html += f"oauth2RedirectUrl: window.location.origin + '{oauth2_redirect_url}',"
if init_oauth:
html += f"""
- ui.initOAuth({json.dumps(jsonable_encoder(init_oauth))})
+ ui.initOAuth({_html_safe_json(jsonable_encoder(init_oauth))})
"""
html += """
--- /dev/null
+from fastapi import FastAPI
+from fastapi.testclient import TestClient
+
+
+def test_root_path_does_not_persist_across_requests():
+ app = FastAPI()
+
+ @app.get("/")
+ def read_root(): # pragma: no cover
+ return {"ok": True}
+
+ # Attacker request with a spoofed root_path
+ attacker_client = TestClient(app, root_path="/evil-api")
+ response1 = attacker_client.get("/openapi.json")
+ data1 = response1.json()
+ assert any(s.get("url") == "/evil-api" for s in data1.get("servers", []))
+
+ # Subsequent legitimate request with no root_path
+ clean_client = TestClient(app)
+ response2 = clean_client.get("/openapi.json")
+ data2 = response2.json()
+ servers = [s.get("url") for s in data2.get("servers", [])]
+ assert "/evil-api" not in servers
+
+
+def test_multiple_different_root_paths_do_not_accumulate():
+ app = FastAPI()
+
+ @app.get("/")
+ def read_root(): # pragma: no cover
+ return {"ok": True}
+
+ for prefix in ["/path-a", "/path-b", "/path-c"]:
+ c = TestClient(app, root_path=prefix)
+ c.get("/openapi.json")
+
+ # A clean request should not have any of them
+ clean_client = TestClient(app)
+ response = clean_client.get("/openapi.json")
+ data = response.json()
+ servers = [s.get("url") for s in data.get("servers", [])]
+ for prefix in ["/path-a", "/path-b", "/path-c"]:
+ assert prefix not in servers, (
+ f"root_path '{prefix}' leaked into clean request: {servers}"
+ )
+
+
+def test_legitimate_root_path_still_appears():
+ app = FastAPI()
+
+ @app.get("/")
+ def read_root(): # pragma: no cover
+ return {"ok": True}
+
+ client = TestClient(app, root_path="/api/v1")
+ response = client.get("/openapi.json")
+ data = response.json()
+ servers = [s.get("url") for s in data.get("servers", [])]
+ assert "/api/v1" in servers
+
+
+def test_configured_servers_not_mutated():
+ configured_servers = [{"url": "https://prod.example.com"}]
+ app = FastAPI(servers=configured_servers)
+
+ @app.get("/")
+ def read_root(): # pragma: no cover
+ return {"ok": True}
+
+ # Request with a rogue root_path
+ attacker_client = TestClient(app, root_path="/evil")
+ attacker_client.get("/openapi.json")
+
+ # The original servers list must be untouched
+ assert configured_servers == [{"url": "https://prod.example.com"}]
--- /dev/null
+from fastapi.openapi.docs import get_swagger_ui_html
+
+
+def test_init_oauth_html_chars_are_escaped():
+ xss_payload = "Evil</script><script>alert(1)</script>"
+ html = get_swagger_ui_html(
+ openapi_url="/openapi.json",
+ title="Test",
+ init_oauth={"appName": xss_payload},
+ )
+ body = html.body.decode()
+
+ assert "</script><script>" not in body
+ assert "\\u003c/script\\u003e\\u003cscript\\u003e" in body
+
+
+def test_swagger_ui_parameters_html_chars_are_escaped():
+ html = get_swagger_ui_html(
+ openapi_url="/openapi.json",
+ title="Test",
+ swagger_ui_parameters={"customKey": "<img src=x onerror=alert(1)>"},
+ )
+ body = html.body.decode()
+ assert "<img src=x onerror=alert(1)>" not in body
+ assert "\\u003cimg" in body
+
+
+def test_normal_init_oauth_still_works():
+ html = get_swagger_ui_html(
+ openapi_url="/openapi.json",
+ title="Test",
+ init_oauth={"clientId": "my-client", "appName": "My App"},
+ )
+ body = html.body.decode()
+ assert '"clientId": "my-client"' in body
+ assert '"appName": "My App"' in body
+ assert "ui.initOAuth" in body