]> git.ipfire.org Git - thirdparty/starlette.git/commitdiff
Remove converter from path when generating `OpenAPI` schema (#1648)
authorMarcelo Trylesinski <marcelotryle@gmail.com>
Tue, 28 Jun 2022 05:23:27 +0000 (07:23 +0200)
committerGitHub <noreply@github.com>
Tue, 28 Jun 2022 05:23:27 +0000 (07:23 +0200)
* Remove converter from path when generating `OpenAPI` schema

* Update starlette/schemas.py

Co-authored-by: Tom Christie <tom@tomchristie.com>
Co-authored-by: Tom Christie <tom@tomchristie.com>
starlette/schemas.py
tests/test_schemas.py

index 6ca764fdcc738be284e23e61f43d27b621e6d6a2..55bf7b397613b5f356a2c16dce068e26559a1ae1 100644 (file)
@@ -1,4 +1,5 @@
 import inspect
+import re
 import typing
 
 from starlette.requests import Request
@@ -49,10 +50,11 @@ class BaseSchemaGenerator:
 
         for route in routes:
             if isinstance(route, Mount):
+                path = self._remove_converter(route.path)
                 routes = route.routes or []
                 sub_endpoints = [
                     EndpointInfo(
-                        path="".join((route.path, sub_endpoint.path)),
+                        path="".join((path, sub_endpoint.path)),
                         http_method=sub_endpoint.http_method,
                         func=sub_endpoint.func,
                     )
@@ -64,23 +66,32 @@ class BaseSchemaGenerator:
                 continue
 
             elif inspect.isfunction(route.endpoint) or inspect.ismethod(route.endpoint):
+                path = self._remove_converter(route.path)
                 for method in route.methods or ["GET"]:
                     if method == "HEAD":
                         continue
                     endpoints_info.append(
-                        EndpointInfo(route.path, method.lower(), route.endpoint)
+                        EndpointInfo(path, method.lower(), route.endpoint)
                     )
             else:
+                path = self._remove_converter(route.path)
                 for method in ["get", "post", "put", "patch", "delete", "options"]:
                     if not hasattr(route.endpoint, method):
                         continue
                     func = getattr(route.endpoint, method)
-                    endpoints_info.append(
-                        EndpointInfo(route.path, method.lower(), func)
-                    )
+                    endpoints_info.append(EndpointInfo(path, method.lower(), func))
 
         return endpoints_info
 
+    def _remove_converter(self, path: str) -> str:
+        """
+        Remove the converter from the path.
+        For example, a route like this:
+            Route("/users/{id:int}", endpoint=get_user, methods=["GET"])
+        Should be represented as `/users/{id}` in the OpenAPI schema.
+        """
+        return re.sub(r":\w+}", "}", path)
+
     def parse_docstring(self, func_or_method: typing.Callable) -> dict:
         """
         Given a function, parse the docstring as YAML and return a dictionary of info.
index fa43785b98fd8c4c511353d20bcb1bead7aec920..26884b3916a6b3fa11ced8fca5d59dad1b79cba2 100644 (file)
@@ -13,6 +13,17 @@ def ws(session):
     pass  # pragma: no cover
 
 
+def get_user(request):
+    """
+    responses:
+        200:
+            description: A user.
+            examples:
+                {"username": "tom"}
+    """
+    pass  # pragma: no cover
+
+
 def list_users(request):
     """
     responses:
@@ -103,6 +114,7 @@ subapp = Starlette(
 app = Starlette(
     routes=[
         WebSocketRoute("/ws", endpoint=ws),
+        Route("/users/{id:int}", endpoint=get_user, methods=["GET"]),
         Route("/users", endpoint=list_users, methods=["GET", "HEAD"]),
         Route("/users", endpoint=create_user, methods=["POST"]),
         Route("/orgs", endpoint=OrganisationsEndpoint),
@@ -168,6 +180,16 @@ def test_schema_generation():
                     }
                 },
             },
+            "/users/{id}": {
+                "get": {
+                    "responses": {
+                        200: {
+                            "description": "A user.",
+                            "examples": {"username": "tom"},
+                        }
+                    }
+                },
+            },
         },
     }
 
@@ -216,6 +238,13 @@ paths:
           description: A user.
           examples:
             username: tom
+  /users/{id}:
+    get:
+      responses:
+        200:
+          description: A user.
+          examples:
+            username: tom
 """