]> git.ipfire.org Git - thirdparty/starlette.git/commitdiff
Set `line-length` to 120 on Ruff (#2679)
authorMarcelo Trylesinski <marcelotryle@gmail.com>
Sun, 1 Sep 2024 13:11:01 +0000 (15:11 +0200)
committerGitHub <noreply@github.com>
Sun, 1 Sep 2024 13:11:01 +0000 (15:11 +0200)
* Set `line-length` to 120 on Ruff

* Add links to selected rules

* Remove empty strings

* Fix more stuff

53 files changed:
pyproject.toml
starlette/_compat.py
starlette/_exception_handler.py
starlette/_utils.py
starlette/applications.py
starlette/authentication.py
starlette/background.py
starlette/concurrency.py
starlette/config.py
starlette/datastructures.py
starlette/endpoints.py
starlette/formparsers.py
starlette/middleware/__init__.py
starlette/middleware/authentication.py
starlette/middleware/base.py
starlette/middleware/cors.py
starlette/middleware/errors.py
starlette/middleware/exceptions.py
starlette/middleware/gzip.py
starlette/middleware/sessions.py
starlette/middleware/trustedhost.py
starlette/middleware/wsgi.py
starlette/requests.py
starlette/responses.py
starlette/routing.py
starlette/schemas.py
starlette/staticfiles.py
starlette/templating.py
starlette/testclient.py
starlette/types.py
starlette/websockets.py
tests/middleware/test_base.py
tests/middleware/test_cors.py
tests/middleware/test_gzip.py
tests/middleware/test_session.py
tests/middleware/test_trusted_host.py
tests/test_applications.py
tests/test_authentication.py
tests/test_background.py
tests/test_config.py
tests/test_convertors.py
tests/test_datastructures.py
tests/test_endpoints.py
tests/test_exceptions.py
tests/test_formparsers.py
tests/test_responses.py
tests/test_routing.py
tests/test_schemas.py
tests/test_staticfiles.py
tests/test_status.py
tests/test_templates.py
tests/test_testclient.py
tests/test_websockets.py

index f2721c870f86ea11d122cc0a1a1a15d92bebe724..52156528e8cd86b0854a3dcb3eafd688e06f5321 100644 (file)
@@ -50,9 +50,19 @@ Source = "https://github.com/encode/starlette"
 [tool.hatch.version]
 path = "starlette/__init__.py"
 
+[tool.ruff]
+line-length = 120
+
 [tool.ruff.lint]
-select = ["E", "F", "I", "FA", "UP"]
-ignore = ["UP031"]
+select = [
+    "E",      # https://docs.astral.sh/ruff/rules/#error-e
+    "F",      # https://docs.astral.sh/ruff/rules/#pyflakes-f
+    "I",      # https://docs.astral.sh/ruff/rules/#isort-i
+    "FA",     # https://docs.astral.sh/ruff/rules/#flake8-future-annotations-fa
+    "UP",     # https://docs.astral.sh/ruff/rules/#pyupgrade-up
+    "RUF100", # https://docs.astral.sh/ruff/rules/#ruff-specific-rules-ruf
+]
+ignore = ["UP031"] # https://docs.astral.sh/ruff/rules/printf-string-formatting/
 
 [tool.ruff.lint.isort]
 combine-as-imports = true
index 9087a76450792f48450fbda2d9d063f8f3016689..718bc9020e41d50913195eef2bf56ea94322d591 100644 (file)
@@ -15,9 +15,7 @@ try:
     # that reject usedforsecurity=True
     hashlib.md5(b"data", usedforsecurity=False)  # type: ignore[call-arg]
 
-    def md5_hexdigest(
-        data: bytes, *, usedforsecurity: bool = True
-    ) -> str:  # pragma: no cover
+    def md5_hexdigest(data: bytes, *, usedforsecurity: bool = True) -> str:  # pragma: no cover
         return hashlib.md5(  # type: ignore[call-arg]
             data, usedforsecurity=usedforsecurity
         ).hexdigest()
index 99cb6b64cd9f5ef4d3064e34aa910d5f2a2a7367..4fbc86394d1cea6c1d8d2af9e7edd1afe24f4f3b 100644 (file)
@@ -22,9 +22,7 @@ ExceptionHandlers = typing.Dict[typing.Any, ExceptionHandler]
 StatusHandlers = typing.Dict[int, ExceptionHandler]
 
 
-def _lookup_exception_handler(
-    exc_handlers: ExceptionHandlers, exc: Exception
-) -> ExceptionHandler | None:
+def _lookup_exception_handler(exc_handlers: ExceptionHandlers, exc: Exception) -> ExceptionHandler | None:
     for cls in type(exc).__mro__:
         if cls in exc_handlers:
             return exc_handlers[cls]
index b6970542b8a978c1455fe9178fa640b4c6b994f2..90bd346fd1211c99e0c0787940f6501263a7dc62 100644 (file)
@@ -37,26 +37,20 @@ def is_async_callable(obj: typing.Any) -> typing.Any:
     while isinstance(obj, functools.partial):
         obj = obj.func
 
-    return asyncio.iscoroutinefunction(obj) or (
-        callable(obj) and asyncio.iscoroutinefunction(obj.__call__)
-    )
+    return asyncio.iscoroutinefunction(obj) or (callable(obj) and asyncio.iscoroutinefunction(obj.__call__))
 
 
 T_co = typing.TypeVar("T_co", covariant=True)
 
 
-class AwaitableOrContextManager(
-    typing.Awaitable[T_co], typing.AsyncContextManager[T_co], typing.Protocol[T_co]
-): ...
+class AwaitableOrContextManager(typing.Awaitable[T_co], typing.AsyncContextManager[T_co], typing.Protocol[T_co]): ...
 
 
 class SupportsAsyncClose(typing.Protocol):
     async def close(self) -> None: ...  # pragma: no cover
 
 
-SupportsAsyncCloseType = typing.TypeVar(
-    "SupportsAsyncCloseType", bound=SupportsAsyncClose, covariant=False
-)
+SupportsAsyncCloseType = typing.TypeVar("SupportsAsyncCloseType", bound=SupportsAsyncClose, covariant=False)
 
 
 class AwaitableOrContextManagerWrapper(typing.Generic[SupportsAsyncCloseType]):
index 913fd4c9dbdfad47a51ad850c087503f554d1868..f34e80ead67fb72122bb9227aabe518fe0bf26a2 100644 (file)
@@ -72,21 +72,15 @@ class Starlette:
 
         self.debug = debug
         self.state = State()
-        self.router = Router(
-            routes, on_startup=on_startup, on_shutdown=on_shutdown, lifespan=lifespan
-        )
-        self.exception_handlers = (
-            {} if exception_handlers is None else dict(exception_handlers)
-        )
+        self.router = Router(routes, on_startup=on_startup, on_shutdown=on_shutdown, lifespan=lifespan)
+        self.exception_handlers = {} if exception_handlers is None else dict(exception_handlers)
         self.user_middleware = [] if middleware is None else list(middleware)
         self.middleware_stack: ASGIApp | None = None
 
     def build_middleware_stack(self) -> ASGIApp:
         debug = self.debug
         error_handler = None
-        exception_handlers: dict[
-            typing.Any, typing.Callable[[Request, Exception], Response]
-        ] = {}
+        exception_handlers: dict[typing.Any, typing.Callable[[Request, Exception], Response]] = {}
 
         for key, value in self.exception_handlers.items():
             if key in (500, Exception):
@@ -97,11 +91,7 @@ class Starlette:
         middleware = (
             [Middleware(ServerErrorMiddleware, handler=error_handler, debug=debug)]
             + self.user_middleware
-            + [
-                Middleware(
-                    ExceptionMiddleware, handlers=exception_handlers, debug=debug
-                )
-            ]
+            + [Middleware(ExceptionMiddleware, handlers=exception_handlers, debug=debug)]
         )
 
         app = self.router
@@ -163,9 +153,7 @@ class Starlette:
         name: str | None = None,
         include_in_schema: bool = True,
     ) -> None:  # pragma: no cover
-        self.router.add_route(
-            path, route, methods=methods, name=name, include_in_schema=include_in_schema
-        )
+        self.router.add_route(path, route, methods=methods, name=name, include_in_schema=include_in_schema)
 
     def add_websocket_route(
         self,
@@ -175,16 +163,14 @@ class Starlette:
     ) -> None:  # pragma: no cover
         self.router.add_websocket_route(path, route, name=name)
 
-    def exception_handler(
-        self, exc_class_or_status_code: int | type[Exception]
-    ) -> typing.Callable:  # type: ignore[type-arg]
+    def exception_handler(self, exc_class_or_status_code: int | type[Exception]) -> typing.Callable:  # type: ignore[type-arg]
         warnings.warn(
-            "The `exception_handler` decorator is deprecated, and will be removed in version 1.0.0. "  # noqa: E501
-            "Refer to https://www.starlette.io/exceptions/ for the recommended approach.",  # noqa: E501
+            "The `exception_handler` decorator is deprecated, and will be removed in version 1.0.0. "
+            "Refer to https://www.starlette.io/exceptions/ for the recommended approach.",
             DeprecationWarning,
         )
 
-        def decorator(func: typing.Callable) -> typing.Callable:  # type: ignore[type-arg]  # noqa: E501
+        def decorator(func: typing.Callable) -> typing.Callable:  # type: ignore[type-arg]
             self.add_exception_handler(exc_class_or_status_code, func)
             return func
 
@@ -205,12 +191,12 @@ class Starlette:
         >>> app = Starlette(routes=routes)
         """
         warnings.warn(
-            "The `route` decorator is deprecated, and will be removed in version 1.0.0. "  # noqa: E501
-            "Refer to https://www.starlette.io/routing/ for the recommended approach.",  # noqa: E501
+            "The `route` decorator is deprecated, and will be removed in version 1.0.0. "
+            "Refer to https://www.starlette.io/routing/ for the recommended approach.",
             DeprecationWarning,
         )
 
-        def decorator(func: typing.Callable) -> typing.Callable:  # type: ignore[type-arg]  # noqa: E501
+        def decorator(func: typing.Callable) -> typing.Callable:  # type: ignore[type-arg]
             self.router.add_route(
                 path,
                 func,
@@ -231,18 +217,18 @@ class Starlette:
         >>> app = Starlette(routes=routes)
         """
         warnings.warn(
-            "The `websocket_route` decorator is deprecated, and will be removed in version 1.0.0. "  # noqa: E501
-            "Refer to https://www.starlette.io/routing/#websocket-routing for the recommended approach.",  # noqa: E501
+            "The `websocket_route` decorator is deprecated, and will be removed in version 1.0.0. "
+            "Refer to https://www.starlette.io/routing/#websocket-routing for the recommended approach.",
             DeprecationWarning,
         )
 
-        def decorator(func: typing.Callable) -> typing.Callable:  # type: ignore[type-arg]  # noqa: E501
+        def decorator(func: typing.Callable) -> typing.Callable:  # type: ignore[type-arg]
             self.router.add_websocket_route(path, func, name=name)
             return func
 
         return decorator
 
-    def middleware(self, middleware_type: str) -> typing.Callable:  # type: ignore[type-arg]  # noqa: E501
+    def middleware(self, middleware_type: str) -> typing.Callable:  # type: ignore[type-arg]
         """
         We no longer document this decorator style API, and its usage is discouraged.
         Instead you should use the following approach:
@@ -251,15 +237,13 @@ class Starlette:
         >>> app = Starlette(middleware=middleware)
         """
         warnings.warn(
-            "The `middleware` decorator is deprecated, and will be removed in version 1.0.0. "  # noqa: E501
-            "Refer to https://www.starlette.io/middleware/#using-middleware for recommended approach.",  # noqa: E501
+            "The `middleware` decorator is deprecated, and will be removed in version 1.0.0. "
+            "Refer to https://www.starlette.io/middleware/#using-middleware for recommended approach.",
             DeprecationWarning,
         )
-        assert (
-            middleware_type == "http"
-        ), 'Currently only middleware("http") is supported.'
+        assert middleware_type == "http", 'Currently only middleware("http") is supported.'
 
-        def decorator(func: typing.Callable) -> typing.Callable:  # type: ignore[type-arg]  # noqa: E501
+        def decorator(func: typing.Callable) -> typing.Callable:  # type: ignore[type-arg]
             self.add_middleware(BaseHTTPMiddleware, dispatch=func)
             return func
 
index f2586a042735f9afc011b104f7d6d48c58c0dfce..4fd866412b5e32fc333866cfea8271f3a7116907 100644 (file)
@@ -31,9 +31,7 @@ def requires(
     scopes: str | typing.Sequence[str],
     status_code: int = 403,
     redirect: str | None = None,
-) -> typing.Callable[
-    [typing.Callable[_P, typing.Any]], typing.Callable[_P, typing.Any]
-]:
+) -> typing.Callable[[typing.Callable[_P, typing.Any]], typing.Callable[_P, typing.Any]]:
     scopes_list = [scopes] if isinstance(scopes, str) else list(scopes)
 
     def decorator(
@@ -45,17 +43,13 @@ def requires(
                 type_ = parameter.name
                 break
         else:
-            raise Exception(
-                f'No "request" or "websocket" argument on function "{func}"'
-            )
+            raise Exception(f'No "request" or "websocket" argument on function "{func}"')
 
         if type_ == "websocket":
             # Handle websocket functions. (Always async)
             @functools.wraps(func)
             async def websocket_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None:
-                websocket = kwargs.get(
-                    "websocket", args[idx] if idx < len(args) else None
-                )
+                websocket = kwargs.get("websocket", args[idx] if idx < len(args) else None)
                 assert isinstance(websocket, WebSocket)
 
                 if not has_required_scope(websocket, scopes_list):
@@ -107,9 +101,7 @@ class AuthenticationError(Exception):
 
 
 class AuthenticationBackend:
-    async def authenticate(
-        self, conn: HTTPConnection
-    ) -> tuple[AuthCredentials, BaseUser] | None:
+    async def authenticate(self, conn: HTTPConnection) -> tuple[AuthCredentials, BaseUser] | None:
         raise NotImplementedError()  # pragma: no cover
 
 
index 1cbed3b2202dbe61dbd6cd67c2bb13f5ec385339..0430fc08bb6b256767b8511220e89ae9373fa53f 100644 (file)
@@ -15,9 +15,7 @@ P = ParamSpec("P")
 
 
 class BackgroundTask:
-    def __init__(
-        self, func: typing.Callable[P, typing.Any], *args: P.args, **kwargs: P.kwargs
-    ) -> None:
+    def __init__(self, func: typing.Callable[P, typing.Any], *args: P.args, **kwargs: P.kwargs) -> None:
         self.func = func
         self.args = args
         self.kwargs = kwargs
@@ -34,9 +32,7 @@ class BackgroundTasks(BackgroundTask):
     def __init__(self, tasks: typing.Sequence[BackgroundTask] | None = None):
         self.tasks = list(tasks) if tasks else []
 
-    def add_task(
-        self, func: typing.Callable[P, typing.Any], *args: P.args, **kwargs: P.kwargs
-    ) -> None:
+    def add_task(self, func: typing.Callable[P, typing.Any], *args: P.args, **kwargs: P.kwargs) -> None:
         task = BackgroundTask(func, *args, **kwargs)
         self.tasks.append(task)
 
index 215e3a63b0b882fd81cd59d137c83e2350a23f50..ce3f5c82b5440453a34cce7abf4b025609f78c0d 100644 (file)
@@ -16,16 +16,15 @@ P = ParamSpec("P")
 T = typing.TypeVar("T")
 
 
-async def run_until_first_complete(*args: tuple[typing.Callable, dict]) -> None:  # type: ignore[type-arg]  # noqa: E501
+async def run_until_first_complete(*args: tuple[typing.Callable, dict]) -> None:  # type: ignore[type-arg]
     warnings.warn(
-        "run_until_first_complete is deprecated "
-        "and will be removed in a future version.",
+        "run_until_first_complete is deprecated and will be removed in a future version.",
         DeprecationWarning,
     )
 
     async with anyio.create_task_group() as task_group:
 
-        async def run(func: typing.Callable[[], typing.Coroutine]) -> None:  # type: ignore[type-arg]  # noqa: E501
+        async def run(func: typing.Callable[[], typing.Coroutine]) -> None:  # type: ignore[type-arg]
             await func()
             task_group.cancel_scope.cancel()
 
@@ -33,9 +32,7 @@ async def run_until_first_complete(*args: tuple[typing.Callable, dict]) -> None:
             task_group.start_soon(run, functools.partial(func, **kwargs))
 
 
-async def run_in_threadpool(
-    func: typing.Callable[P, T], *args: P.args, **kwargs: P.kwargs
-) -> T:
+async def run_in_threadpool(func: typing.Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
     if kwargs:  # pragma: no cover
         # run_sync doesn't accept 'kwargs', so bind them in here
         func = functools.partial(func, **kwargs)
index 4c3dfe5b0194682d523d2583aaad952f65816748..ca15c564670271e66790000a308a7d7e981e0bac 100644 (file)
@@ -25,18 +25,12 @@ class Environ(typing.MutableMapping[str, str]):
 
     def __setitem__(self, key: str, value: str) -> None:
         if key in self._has_been_read:
-            raise EnvironError(
-                f"Attempting to set environ['{key}'], but the value has already been "
-                "read."
-            )
+            raise EnvironError(f"Attempting to set environ['{key}'], but the value has already been read.")
         self._environ.__setitem__(key, value)
 
     def __delitem__(self, key: str) -> None:
         if key in self._has_been_read:
-            raise EnvironError(
-                f"Attempting to delete environ['{key}'], but the value has already "
-                "been read."
-            )
+            raise EnvironError(f"Attempting to delete environ['{key}'], but the value has already been read.")
         self._environ.__delitem__(key)
 
     def __iter__(self) -> typing.Iterator[str]:
@@ -85,9 +79,7 @@ class Config:
     ) -> T: ...
 
     @typing.overload
-    def __call__(
-        self, key: str, cast: type[str] = ..., default: T = ...
-    ) -> T | str: ...
+    def __call__(self, key: str, cast: type[str] = ..., default: T = ...) -> T | str: ...
 
     def __call__(
         self,
@@ -138,13 +130,9 @@ class Config:
             mapping = {"true": True, "1": True, "false": False, "0": False}
             value = value.lower()
             if value not in mapping:
-                raise ValueError(
-                    f"Config '{key}' has value '{value}'. Not a valid bool."
-                )
+                raise ValueError(f"Config '{key}' has value '{value}'. Not a valid bool.")
             return mapping[value]
         try:
             return cast(value)
         except (TypeError, ValueError):
-            raise ValueError(
-                f"Config '{key}' has value '{value}'. Not a valid {cast.__name__}."
-            )
+            raise ValueError(f"Config '{key}' has value '{value}'. Not a valid {cast.__name__}.")
index 54b5e54f3bf8a19a36a81ad4e5aee970a512153b..90a7296a09a63a6993d652988782d991b2154686 100644 (file)
@@ -108,12 +108,7 @@ class URL:
         return self.scheme in ("https", "wss")
 
     def replace(self, **kwargs: typing.Any) -> URL:
-        if (
-            "username" in kwargs
-            or "password" in kwargs
-            or "hostname" in kwargs
-            or "port" in kwargs
-        ):
+        if "username" in kwargs or "password" in kwargs or "hostname" in kwargs or "port" in kwargs:
             hostname = kwargs.pop("hostname", None)
             port = kwargs.pop("port", self.port)
             username = kwargs.pop("username", self.username)
@@ -264,17 +259,12 @@ class ImmutableMultiDict(typing.Mapping[_KeyType, _CovariantValueType]):
 
         value: typing.Any = args[0] if args else []
         if kwargs:
-            value = (
-                ImmutableMultiDict(value).multi_items()
-                + ImmutableMultiDict(kwargs).multi_items()
-            )
+            value = ImmutableMultiDict(value).multi_items() + ImmutableMultiDict(kwargs).multi_items()
 
         if not value:
             _items: list[tuple[typing.Any, typing.Any]] = []
         elif hasattr(value, "multi_items"):
-            value = typing.cast(
-                ImmutableMultiDict[_KeyType, _CovariantValueType], value
-            )
+            value = typing.cast(ImmutableMultiDict[_KeyType, _CovariantValueType], value)
             _items = list(value.multi_items())
         elif hasattr(value, "items"):
             value = typing.cast(typing.Mapping[_KeyType, _CovariantValueType], value)
@@ -371,9 +361,7 @@ class MultiDict(ImmutableMultiDict[typing.Any, typing.Any]):
 
     def update(
         self,
-        *args: MultiDict
-        | typing.Mapping[typing.Any, typing.Any]
-        | list[tuple[typing.Any, typing.Any]],
+        *args: MultiDict | typing.Mapping[typing.Any, typing.Any] | list[tuple[typing.Any, typing.Any]],
         **kwargs: typing.Any,
     ) -> None:
         value = MultiDict(*args, **kwargs)
@@ -403,9 +391,7 @@ class QueryParams(ImmutableMultiDict[str, str]):
         if isinstance(value, str):
             super().__init__(parse_qsl(value, keep_blank_values=True), **kwargs)
         elif isinstance(value, bytes):
-            super().__init__(
-                parse_qsl(value.decode("latin-1"), keep_blank_values=True), **kwargs
-            )
+            super().__init__(parse_qsl(value.decode("latin-1"), keep_blank_values=True), **kwargs)
         else:
             super().__init__(*args, **kwargs)  # type: ignore[arg-type]
         self._list = [(str(k), str(v)) for k, v in self._list]
@@ -490,9 +476,7 @@ class FormData(ImmutableMultiDict[str, typing.Union[UploadFile, str]]):
 
     def __init__(
         self,
-        *args: FormData
-        | typing.Mapping[str, str | UploadFile]
-        | list[tuple[str, str | UploadFile]],
+        *args: FormData | typing.Mapping[str, str | UploadFile] | list[tuple[str, str | UploadFile]],
         **kwargs: str | UploadFile,
     ) -> None:
         super().__init__(*args, **kwargs)
@@ -518,10 +502,7 @@ class Headers(typing.Mapping[str, str]):
         if headers is not None:
             assert raw is None, 'Cannot set both "headers" and "raw".'
             assert scope is None, 'Cannot set both "headers" and "scope".'
-            self._list = [
-                (key.lower().encode("latin-1"), value.encode("latin-1"))
-                for key, value in headers.items()
-            ]
+            self._list = [(key.lower().encode("latin-1"), value.encode("latin-1")) for key, value in headers.items()]
         elif raw is not None:
             assert scope is None, 'Cannot set both "raw" and "scope".'
             self._list = raw
@@ -541,18 +522,11 @@ class Headers(typing.Mapping[str, str]):
         return [value.decode("latin-1") for key, value in self._list]
 
     def items(self) -> list[tuple[str, str]]:  # type: ignore[override]
-        return [
-            (key.decode("latin-1"), value.decode("latin-1"))
-            for key, value in self._list
-        ]
+        return [(key.decode("latin-1"), value.decode("latin-1")) for key, value in self._list]
 
     def getlist(self, key: str) -> list[str]:
         get_header_key = key.lower().encode("latin-1")
-        return [
-            item_value.decode("latin-1")
-            for item_key, item_value in self._list
-            if item_key == get_header_key
-        ]
+        return [item_value.decode("latin-1") for item_key, item_value in self._list if item_key == get_header_key]
 
     def mutablecopy(self) -> MutableHeaders:
         return MutableHeaders(raw=self._list[:])
index 57f718824eebc1ef749808a9f9dd6417a8e08fb3..eb1dace42df506f8f92cc5b648d3f9246e9b2bcb 100644 (file)
@@ -30,15 +30,9 @@ class HTTPEndpoint:
 
     async def dispatch(self) -> None:
         request = Request(self.scope, receive=self.receive)
-        handler_name = (
-            "get"
-            if request.method == "HEAD" and not hasattr(self, "head")
-            else request.method.lower()
-        )
-
-        handler: typing.Callable[[Request], typing.Any] = getattr(
-            self, handler_name, self.method_not_allowed
-        )
+        handler_name = "get" if request.method == "HEAD" and not hasattr(self, "head") else request.method.lower()
+
+        handler: typing.Callable[[Request], typing.Any] = getattr(self, handler_name, self.method_not_allowed)
         is_async = is_async_callable(handler)
         if is_async:
             response = await handler(request)
@@ -81,9 +75,7 @@ class WebSocketEndpoint:
                     data = await self.decode(websocket, message)
                     await self.on_receive(websocket, data)
                 elif message["type"] == "websocket.disconnect":
-                    close_code = int(
-                        message.get("code") or status.WS_1000_NORMAL_CLOSURE
-                    )
+                    close_code = int(message.get("code") or status.WS_1000_NORMAL_CLOSURE)
                     break
         except Exception as exc:
             close_code = status.WS_1011_INTERNAL_ERROR
@@ -116,9 +108,7 @@ class WebSocketEndpoint:
                 await websocket.close(code=status.WS_1003_UNSUPPORTED_DATA)
                 raise RuntimeError("Malformed JSON data received.")
 
-        assert (
-            self.encoding is None
-        ), f"Unsupported 'encoding' attribute {self.encoding}"
+        assert self.encoding is None, f"Unsupported 'encoding' attribute {self.encoding}"
         return message["text"] if message.get("text") else message["bytes"]
 
     async def on_connect(self, websocket: WebSocket) -> None:
index 2e12c7faac8fe3837afffc331622e9a4efb0a919..56f63a8be9013453e5d82a084e65ac273eccf33c 100644 (file)
@@ -46,12 +46,8 @@ class MultiPartException(Exception):
 
 
 class FormParser:
-    def __init__(
-        self, headers: Headers, stream: typing.AsyncGenerator[bytes, None]
-    ) -> None:
-        assert (
-            multipart is not None
-        ), "The `python-multipart` library must be installed to use form parsing."
+    def __init__(self, headers: Headers, stream: typing.AsyncGenerator[bytes, None]) -> None:
+        assert multipart is not None, "The `python-multipart` library must be installed to use form parsing."
         self.headers = headers
         self.stream = stream
         self.messages: list[tuple[FormMessage, bytes]] = []
@@ -128,9 +124,7 @@ class MultiPartParser:
         max_files: int | float = 1000,
         max_fields: int | float = 1000,
     ) -> None:
-        assert (
-            multipart is not None
-        ), "The `python-multipart` library must be installed to use form parsing."
+        assert multipart is not None, "The `python-multipart` library must be installed to use form parsing."
         self.headers = headers
         self.stream = stream
         self.max_files = max_files
@@ -181,30 +175,20 @@ class MultiPartParser:
         field = self._current_partial_header_name.lower()
         if field == b"content-disposition":
             self._current_part.content_disposition = self._current_partial_header_value
-        self._current_part.item_headers.append(
-            (field, self._current_partial_header_value)
-        )
+        self._current_part.item_headers.append((field, self._current_partial_header_value))
         self._current_partial_header_name = b""
         self._current_partial_header_value = b""
 
     def on_headers_finished(self) -> None:
-        disposition, options = parse_options_header(
-            self._current_part.content_disposition
-        )
+        disposition, options = parse_options_header(self._current_part.content_disposition)
         try:
-            self._current_part.field_name = _user_safe_decode(
-                options[b"name"], self._charset
-            )
+            self._current_part.field_name = _user_safe_decode(options[b"name"], self._charset)
         except KeyError:
-            raise MultiPartException(
-                'The Content-Disposition header field "name" must be ' "provided."
-            )
+            raise MultiPartException('The Content-Disposition header field "name" must be provided.')
         if b"filename" in options:
             self._current_files += 1
             if self._current_files > self.max_files:
-                raise MultiPartException(
-                    f"Too many files. Maximum number of files is {self.max_files}."
-                )
+                raise MultiPartException(f"Too many files. Maximum number of files is {self.max_files}.")
             filename = _user_safe_decode(options[b"filename"], self._charset)
             tempfile = SpooledTemporaryFile(max_size=self.max_file_size)
             self._files_to_close_on_error.append(tempfile)
@@ -217,9 +201,7 @@ class MultiPartParser:
         else:
             self._current_fields += 1
             if self._current_fields > self.max_fields:
-                raise MultiPartException(
-                    f"Too many fields. Maximum number of fields is {self.max_fields}."
-                )
+                raise MultiPartException(f"Too many fields. Maximum number of fields is {self.max_fields}.")
             self._current_part.file = None
 
     def on_end(self) -> None:
index d9e64f574e0d94768f2e142ff9e1b26e828b1db1..8566aac08a6d1bce4d8fd3fdc7457f56f1852f11 100644 (file)
@@ -14,13 +14,9 @@ P = ParamSpec("P")
 
 
 class _MiddlewareClass(Protocol[P]):
-    def __init__(
-        self, app: ASGIApp, *args: P.args, **kwargs: P.kwargs
-    ) -> None: ...  # pragma: no cover
+    def __init__(self, app: ASGIApp, *args: P.args, **kwargs: P.kwargs) -> None: ...  # pragma: no cover
 
-    async def __call__(
-        self, scope: Scope, receive: Receive, send: Send
-    ) -> None: ...  # pragma: no cover
+    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: ...  # pragma: no cover
 
 
 class Middleware:
index 966c639bb6360d3047f254f3c8a4142dada043cf..8555ee0780e98b052eb463d55a1c18e39b257762 100644 (file)
@@ -18,14 +18,13 @@ class AuthenticationMiddleware:
         self,
         app: ASGIApp,
         backend: AuthenticationBackend,
-        on_error: typing.Callable[[HTTPConnection, AuthenticationError], Response]
-        | None = None,
+        on_error: typing.Callable[[HTTPConnection, AuthenticationError], Response] | None = None,
     ) -> None:
         self.app = app
         self.backend = backend
-        self.on_error: typing.Callable[
-            [HTTPConnection, AuthenticationError], Response
-        ] = on_error if on_error is not None else self.default_on_error
+        self.on_error: typing.Callable[[HTTPConnection, AuthenticationError], Response] = (
+            on_error if on_error is not None else self.default_on_error
+        )
 
     async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
         if scope["type"] not in ["http", "websocket"]:
index 87c0f51f8ba7cf5704006361c1c9e79b2fba6083..2ac6f7f7f0afd4e5c2e7d1b5f226536262b433f6 100644 (file)
@@ -11,9 +11,7 @@ from starlette.responses import AsyncContentStream, Response
 from starlette.types import ASGIApp, Message, Receive, Scope, Send
 
 RequestResponseEndpoint = typing.Callable[[Request], typing.Awaitable[Response]]
-DispatchFunction = typing.Callable[
-    [Request, RequestResponseEndpoint], typing.Awaitable[Response]
-]
+DispatchFunction = typing.Callable[[Request, RequestResponseEndpoint], typing.Awaitable[Response]]
 T = typing.TypeVar("T")
 
 
@@ -180,9 +178,7 @@ class BaseHTTPMiddleware:
                 if app_exc is not None:
                     raise app_exc
 
-            response = _StreamingResponse(
-                status_code=message["status"], content=body_stream(), info=info
-            )
+            response = _StreamingResponse(status_code=message["status"], content=body_stream(), info=info)
             response.raw_headers = message["headers"]
             return response
 
@@ -192,9 +188,7 @@ class BaseHTTPMiddleware:
                 await response(scope, wrapped_receive, send)
                 response_sent.set()
 
-    async def dispatch(
-        self, request: Request, call_next: RequestResponseEndpoint
-    ) -> Response:
+    async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
         raise NotImplementedError()  # pragma: no cover
 
 
index 4b8e97bc9dc844e9b36e31f06abd7840d2e58de0..61502691abdcde4ce790cb16c4d28002a6241311 100644 (file)
@@ -96,9 +96,7 @@ class CORSMiddleware:
         if self.allow_all_origins:
             return True
 
-        if self.allow_origin_regex is not None and self.allow_origin_regex.fullmatch(
-            origin
-        ):
+        if self.allow_origin_regex is not None and self.allow_origin_regex.fullmatch(origin):
             return True
 
         return origin in self.allow_origins
@@ -141,15 +139,11 @@ class CORSMiddleware:
 
         return PlainTextResponse("OK", status_code=200, headers=headers)
 
-    async def simple_response(
-        self, scope: Scope, receive: Receive, send: Send, request_headers: Headers
-    ) -> None:
+    async def simple_response(self, scope: Scope, receive: Receive, send: Send, request_headers: Headers) -> None:
         send = functools.partial(self.send, send=send, request_headers=request_headers)
         await self.app(scope, receive, send)
 
-    async def send(
-        self, message: Message, send: Send, request_headers: Headers
-    ) -> None:
+    async def send(self, message: Message, send: Send, request_headers: Headers) -> None:
         if message["type"] != "http.response.start":
             await send(message)
             return
index 3fc4a44024e5aa62e037cba391ecdbc82d201ef4..76ad776be2272b6ee6ad1e5cd948762465bf0dcc 100644 (file)
@@ -186,9 +186,7 @@ class ServerErrorMiddleware:
             # to optionally raise the error within the test case.
             raise exc
 
-    def format_line(
-        self, index: int, line: str, frame_lineno: int, frame_index: int
-    ) -> str:
+    def format_line(self, index: int, line: str, frame_lineno: int, frame_index: int) -> str:
         values = {
             # HTML escape - line could contain < or >
             "line": html.escape(line).replace(" ", "&nbsp"),
@@ -225,9 +223,7 @@ class ServerErrorMiddleware:
         return FRAME_TEMPLATE.format(**values)
 
     def generate_html(self, exc: Exception, limit: int = 7) -> str:
-        traceback_obj = traceback.TracebackException.from_exception(
-            exc, capture_locals=True
-        )
+        traceback_obj = traceback.TracebackException.from_exception(exc, capture_locals=True)
 
         exc_html = ""
         is_collapsed = False
index b2bf88dbfe7bc40492fd5bbce4d85de6c431d806..d708929e3bae38a055885aa9dd4573b25c87aa3e 100644 (file)
@@ -18,10 +18,7 @@ class ExceptionMiddleware:
     def __init__(
         self,
         app: ASGIApp,
-        handlers: typing.Mapping[
-            typing.Any, typing.Callable[[Request, Exception], Response]
-        ]
-        | None = None,
+        handlers: typing.Mapping[typing.Any, typing.Callable[[Request, Exception], Response]] | None = None,
         debug: bool = False,
     ) -> None:
         self.app = app
@@ -68,9 +65,7 @@ class ExceptionMiddleware:
         assert isinstance(exc, HTTPException)
         if exc.status_code in {204, 304}:
             return Response(status_code=exc.status_code, headers=exc.headers)
-        return PlainTextResponse(
-            exc.detail, status_code=exc.status_code, headers=exc.headers
-        )
+        return PlainTextResponse(exc.detail, status_code=exc.status_code, headers=exc.headers)
 
     async def websocket_exception(self, websocket: WebSocket, exc: Exception) -> None:
         assert isinstance(exc, WebSocketException)
index 0579e0410acc7e113f3cc176c3d5b0a31db25121..127b91e7a77b88b8ad1fc12b5ebad8ae46bc1aea 100644 (file)
@@ -7,9 +7,7 @@ from starlette.types import ASGIApp, Message, Receive, Scope, Send
 
 
 class GZipMiddleware:
-    def __init__(
-        self, app: ASGIApp, minimum_size: int = 500, compresslevel: int = 9
-    ) -> None:
+    def __init__(self, app: ASGIApp, minimum_size: int = 500, compresslevel: int = 9) -> None:
         self.app = app
         self.minimum_size = minimum_size
         self.compresslevel = compresslevel
@@ -18,9 +16,7 @@ class GZipMiddleware:
         if scope["type"] == "http":
             headers = Headers(scope=scope)
             if "gzip" in headers.get("Accept-Encoding", ""):
-                responder = GZipResponder(
-                    self.app, self.minimum_size, compresslevel=self.compresslevel
-                )
+                responder = GZipResponder(self.app, self.minimum_size, compresslevel=self.compresslevel)
                 await responder(scope, receive, send)
                 return
         await self.app(scope, receive, send)
@@ -35,9 +31,7 @@ class GZipResponder:
         self.started = False
         self.content_encoding_set = False
         self.gzip_buffer = io.BytesIO()
-        self.gzip_file = gzip.GzipFile(
-            mode="wb", fileobj=self.gzip_buffer, compresslevel=compresslevel
-        )
+        self.gzip_file = gzip.GzipFile(mode="wb", fileobj=self.gzip_buffer, compresslevel=compresslevel)
 
     async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
         self.send = send
index 5855912cac339b1f57bfa1be310b2bf57128c508..5f9fcd883b69fb960b067f316d49401aba425b4c 100644 (file)
@@ -61,7 +61,7 @@ class SessionMiddleware:
                     data = b64encode(json.dumps(scope["session"]).encode("utf-8"))
                     data = self.signer.sign(data)
                     headers = MutableHeaders(scope=message)
-                    header_value = "{session_cookie}={data}; path={path}; {max_age}{security_flags}".format(  # noqa E501
+                    header_value = "{session_cookie}={data}; path={path}; {max_age}{security_flags}".format(
                         session_cookie=self.session_cookie,
                         data=data.decode("utf-8"),
                         path=self.path,
@@ -72,7 +72,7 @@ class SessionMiddleware:
                 elif not initial_session_was_empty:
                     # The session has been cleared.
                     headers = MutableHeaders(scope=message)
-                    header_value = "{session_cookie}={data}; path={path}; {expires}{security_flags}".format(  # noqa E501
+                    header_value = "{session_cookie}={data}; path={path}; {expires}{security_flags}".format(
                         session_cookie=self.session_cookie,
                         data="null",
                         path=self.path,
index 59e527363348272320b761950bfd3aadf0076c92..2d1c999e25f40929aacd526ad8b737bbd6394e1f 100644 (file)
@@ -41,9 +41,7 @@ class TrustedHostMiddleware:
         is_valid_host = False
         found_www_redirect = False
         for pattern in self.allowed_hosts:
-            if host == pattern or (
-                pattern.startswith("*") and host.endswith(pattern[1:])
-            ):
+            if host == pattern or (pattern.startswith("*") and host.endswith(pattern[1:])):
                 is_valid_host = True
                 break
             elif "www." + host == pattern:
index c9a7e132814684916bc013aad65f8b826d0f685c..71f4ab5de93cf36c160f66623f75632a431cbfd0 100644 (file)
@@ -89,9 +89,7 @@ class WSGIResponder:
         self.scope = scope
         self.status = None
         self.response_headers = None
-        self.stream_send, self.stream_receive = anyio.create_memory_object_stream(
-            math.inf
-        )
+        self.stream_send, self.stream_receive = anyio.create_memory_object_stream(math.inf)
         self.response_started = False
         self.exc_info: typing.Any = None
 
@@ -151,6 +149,4 @@ class WSGIResponder:
                 {"type": "http.response.body", "body": chunk, "more_body": True},
             )
 
-        anyio.from_thread.run(
-            self.stream_send.send, {"type": "http.response.body", "body": b""}
-        )
+        anyio.from_thread.run(self.stream_send.send, {"type": "http.response.body", "body": b""})
index a2fdfd81eda6a308c00cb2b36db8af8dc77e6e2e..23f8ac70a00d294237a55f4ba90f057e5360857b 100644 (file)
@@ -104,9 +104,7 @@ class HTTPConnection(typing.Mapping[str, typing.Any]):
             # This is used by request.url_for, it might be used inside a Mount which
             # would have its own child scope with its own root_path, but the base URL
             # for url_for should still be the top level app root path.
-            app_root_path = base_url_scope.get(
-                "app_root_path", base_url_scope.get("root_path", "")
-            )
+            app_root_path = base_url_scope.get("app_root_path", base_url_scope.get("root_path", ""))
             path = app_root_path
             if not path.endswith("/"):
                 path += "/"
@@ -153,23 +151,17 @@ class HTTPConnection(typing.Mapping[str, typing.Any]):
 
     @property
     def session(self) -> dict[str, typing.Any]:
-        assert (
-            "session" in self.scope
-        ), "SessionMiddleware must be installed to access request.session"
+        assert "session" in self.scope, "SessionMiddleware must be installed to access request.session"
         return self.scope["session"]  # type: ignore[no-any-return]
 
     @property
     def auth(self) -> typing.Any:
-        assert (
-            "auth" in self.scope
-        ), "AuthenticationMiddleware must be installed to access request.auth"
+        assert "auth" in self.scope, "AuthenticationMiddleware must be installed to access request.auth"
         return self.scope["auth"]
 
     @property
     def user(self) -> typing.Any:
-        assert (
-            "user" in self.scope
-        ), "AuthenticationMiddleware must be installed to access request.user"
+        assert "user" in self.scope, "AuthenticationMiddleware must be installed to access request.user"
         return self.scope["user"]
 
     @property
@@ -199,9 +191,7 @@ async def empty_send(message: Message) -> typing.NoReturn:
 class Request(HTTPConnection):
     _form: FormData | None
 
-    def __init__(
-        self, scope: Scope, receive: Receive = empty_receive, send: Send = empty_send
-    ):
+    def __init__(self, scope: Scope, receive: Receive = empty_receive, send: Send = empty_send):
         super().__init__(scope)
         assert scope["type"] == "http"
         self._receive = receive
@@ -252,9 +242,7 @@ class Request(HTTPConnection):
             self._json = json.loads(body)
         return self._json
 
-    async def _get_form(
-        self, *, max_files: int | float = 1000, max_fields: int | float = 1000
-    ) -> FormData:
+    async def _get_form(self, *, max_files: int | float = 1000, max_fields: int | float = 1000) -> FormData:
         if self._form is None:
             assert (
                 parse_options_header is not None
@@ -285,9 +273,7 @@ class Request(HTTPConnection):
     def form(
         self, *, max_files: int | float = 1000, max_fields: int | float = 1000
     ) -> AwaitableOrContextManager[FormData]:
-        return AwaitableOrContextManagerWrapper(
-            self._get_form(max_files=max_files, max_fields=max_fields)
-        )
+        return AwaitableOrContextManagerWrapper(self._get_form(max_files=max_files, max_fields=max_fields))
 
     async def close(self) -> None:
         if self._form is not None:
@@ -312,9 +298,5 @@ class Request(HTTPConnection):
             raw_headers: list[tuple[bytes, bytes]] = []
             for name in SERVER_PUSH_HEADERS_TO_COPY:
                 for value in self.headers.getlist(name):
-                    raw_headers.append(
-                        (name.encode("latin-1"), value.encode("latin-1"))
-                    )
-            await self._send(
-                {"type": "http.response.push", "path": path, "headers": raw_headers}
-            )
+                    raw_headers.append((name.encode("latin-1"), value.encode("latin-1")))
+            await self._send({"type": "http.response.push", "path": path, "headers": raw_headers})
index 4f15404ca1bed7d3c231a1fcacc9ef04306cdba9..06d6ce5ca4319dcbc0d751b0625062ad0b13d086 100644 (file)
@@ -54,10 +54,7 @@ class Response:
             populate_content_length = True
             populate_content_type = True
         else:
-            raw_headers = [
-                (k.lower().encode("latin-1"), v.encode("latin-1"))
-                for k, v in headers.items()
-            ]
+            raw_headers = [(k.lower().encode("latin-1"), v.encode("latin-1")) for k, v in headers.items()]
             keys = [h[0] for h in raw_headers]
             populate_content_length = b"content-length" not in keys
             populate_content_type = b"content-type" not in keys
@@ -73,10 +70,7 @@ class Response:
 
         content_type = self.media_type
         if content_type is not None and populate_content_type:
-            if (
-                content_type.startswith("text/")
-                and "charset=" not in content_type.lower()
-            ):
+            if content_type.startswith("text/") and "charset=" not in content_type.lower():
                 content_type += "; charset=" + self.charset
             raw_headers.append((b"content-type", content_type.encode("latin-1")))
 
@@ -201,9 +195,7 @@ class RedirectResponse(Response):
         headers: typing.Mapping[str, str] | None = None,
         background: BackgroundTask | None = None,
     ) -> None:
-        super().__init__(
-            content=b"", status_code=status_code, headers=headers, background=background
-        )
+        super().__init__(content=b"", status_code=status_code, headers=headers, background=background)
         self.headers["location"] = quote(str(url), safe=":/%#?=@[]!$&'()*+,;")
 
 
@@ -299,11 +291,9 @@ class FileResponse(Response):
         if self.filename is not None:
             content_disposition_filename = quote(self.filename)
             if content_disposition_filename != self.filename:
-                content_disposition = f"{content_disposition_type}; filename*=utf-8''{content_disposition_filename}"  # noqa: E501
+                content_disposition = f"{content_disposition_type}; filename*=utf-8''{content_disposition_filename}"
             else:
-                content_disposition = (
-                    f'{content_disposition_type}; filename="{self.filename}"'
-                )
+                content_disposition = f'{content_disposition_type}; filename="{self.filename}"'
             self.headers.setdefault("content-disposition", content_disposition)
         self.stat_result = stat_result
         if stat_result is not None:
index 481b13f5d91950d23c5762bb5463326b15840102..cde771563b35709ab0dce9b1b4c99802fa916966 100644 (file)
@@ -47,8 +47,7 @@ def iscoroutinefunction_or_partial(obj: typing.Any) -> bool:  # pragma: no cover
     including those wrapped in functools.partial objects.
     """
     warnings.warn(
-        "iscoroutinefunction_or_partial is deprecated, "
-        "and will be removed in a future release.",
+        "iscoroutinefunction_or_partial is deprecated, and will be removed in a future release.",
         DeprecationWarning,
     )
     while isinstance(obj, functools.partial):
@@ -143,9 +142,7 @@ def compile_path(
     for match in PARAM_REGEX.finditer(path):
         param_name, convertor_type = match.groups("str")
         convertor_type = convertor_type.lstrip(":")
-        assert (
-            convertor_type in CONVERTOR_TYPES
-        ), f"Unknown path convertor '{convertor_type}'"
+        assert convertor_type in CONVERTOR_TYPES, f"Unknown path convertor '{convertor_type}'"
         convertor = CONVERTOR_TYPES[convertor_type]
 
         path_regex += re.escape(path[idx : match.start()])
@@ -275,9 +272,7 @@ class Route(BaseRoute):
         if name != self.name or seen_params != expected_params:
             raise NoMatchFound(name, path_params)
 
-        path, remaining_params = replace_params(
-            self.path_format, self.param_convertors, path_params
-        )
+        path, remaining_params = replace_params(self.path_format, self.param_convertors, path_params)
         assert not remaining_params
         return URLPath(path=path, protocol="http")
 
@@ -287,9 +282,7 @@ class Route(BaseRoute):
             if "app" in scope:
                 raise HTTPException(status_code=405, headers=headers)
             else:
-                response = PlainTextResponse(
-                    "Method Not Allowed", status_code=405, headers=headers
-                )
+                response = PlainTextResponse("Method Not Allowed", status_code=405, headers=headers)
             await response(scope, receive, send)
         else:
             await self.app(scope, receive, send)
@@ -361,9 +354,7 @@ class WebSocketRoute(BaseRoute):
         if name != self.name or seen_params != expected_params:
             raise NoMatchFound(name, path_params)
 
-        path, remaining_params = replace_params(
-            self.path_format, self.param_convertors, path_params
-        )
+        path, remaining_params = replace_params(self.path_format, self.param_convertors, path_params)
         assert not remaining_params
         return URLPath(path=path, protocol="websocket")
 
@@ -371,11 +362,7 @@ class WebSocketRoute(BaseRoute):
         await self.app(scope, receive, send)
 
     def __eq__(self, other: typing.Any) -> bool:
-        return (
-            isinstance(other, WebSocketRoute)
-            and self.path == other.path
-            and self.endpoint == other.endpoint
-        )
+        return isinstance(other, WebSocketRoute) and self.path == other.path and self.endpoint == other.endpoint
 
     def __repr__(self) -> str:
         return f"{self.__class__.__name__}(path={self.path!r}, name={self.name!r})"
@@ -392,9 +379,7 @@ class Mount(BaseRoute):
         middleware: typing.Sequence[Middleware] | None = None,
     ) -> None:
         assert path == "" or path.startswith("/"), "Routed paths must start with '/'"
-        assert (
-            app is not None or routes is not None
-        ), "Either 'app=...', or 'routes=' must be specified"
+        assert app is not None or routes is not None, "Either 'app=...', or 'routes=' must be specified"
         self.path = path.rstrip("/")
         if app is not None:
             self._base_app: ASGIApp = app
@@ -405,9 +390,7 @@ class Mount(BaseRoute):
             for cls, args, kwargs in reversed(middleware):
                 self.app = cls(app=self.app, *args, **kwargs)
         self.name = name
-        self.path_regex, self.path_format, self.param_convertors = compile_path(
-            self.path + "/{path:path}"
-        )
+        self.path_regex, self.path_format, self.param_convertors = compile_path(self.path + "/{path:path}")
 
     @property
     def routes(self) -> list[BaseRoute]:
@@ -450,9 +433,7 @@ class Mount(BaseRoute):
         if self.name is not None and name == self.name and "path" in path_params:
             # 'name' matches "<mount_name>".
             path_params["path"] = path_params["path"].lstrip("/")
-            path, remaining_params = replace_params(
-                self.path_format, self.param_convertors, path_params
-            )
+            path, remaining_params = replace_params(self.path_format, self.param_convertors, path_params)
             if not remaining_params:
                 return URLPath(path=path)
         elif self.name is None or name.startswith(self.name + ":"):
@@ -464,17 +445,13 @@ class Mount(BaseRoute):
                 remaining_name = name[len(self.name) + 1 :]
             path_kwarg = path_params.get("path")
             path_params["path"] = ""
-            path_prefix, remaining_params = replace_params(
-                self.path_format, self.param_convertors, path_params
-            )
+            path_prefix, remaining_params = replace_params(self.path_format, self.param_convertors, path_params)
             if path_kwarg is not None:
                 remaining_params["path"] = path_kwarg
             for route in self.routes or []:
                 try:
                     url = route.url_path_for(remaining_name, **remaining_params)
-                    return URLPath(
-                        path=path_prefix.rstrip("/") + str(url), protocol=url.protocol
-                    )
+                    return URLPath(path=path_prefix.rstrip("/") + str(url), protocol=url.protocol)
                 except NoMatchFound:
                     pass
         raise NoMatchFound(name, path_params)
@@ -483,11 +460,7 @@ class Mount(BaseRoute):
         await self.app(scope, receive, send)
 
     def __eq__(self, other: typing.Any) -> bool:
-        return (
-            isinstance(other, Mount)
-            and self.path == other.path
-            and self.app == other.app
-        )
+        return isinstance(other, Mount) and self.path == other.path and self.app == other.app
 
     def __repr__(self) -> str:
         class_name = self.__class__.__name__
@@ -526,9 +499,7 @@ class Host(BaseRoute):
         if self.name is not None and name == self.name and "path" in path_params:
             # 'name' matches "<mount_name>".
             path = path_params.pop("path")
-            host, remaining_params = replace_params(
-                self.host_format, self.param_convertors, path_params
-            )
+            host, remaining_params = replace_params(self.host_format, self.param_convertors, path_params)
             if not remaining_params:
                 return URLPath(path=path, host=host)
         elif self.name is None or name.startswith(self.name + ":"):
@@ -538,9 +509,7 @@ class Host(BaseRoute):
             else:
                 # 'name' matches "<mount_name>:<child_name>".
                 remaining_name = name[len(self.name) + 1 :]
-            host, remaining_params = replace_params(
-                self.host_format, self.param_convertors, path_params
-            )
+            host, remaining_params = replace_params(self.host_format, self.param_convertors, path_params)
             for route in self.routes or []:
                 try:
                     url = route.url_path_for(remaining_name, **remaining_params)
@@ -553,11 +522,7 @@ class Host(BaseRoute):
         await self.app(scope, receive, send)
 
     def __eq__(self, other: typing.Any) -> bool:
-        return (
-            isinstance(other, Host)
-            and self.host == other.host
-            and self.app == other.app
-        )
+        return isinstance(other, Host) and self.host == other.host and self.app == other.app
 
     def __repr__(self) -> str:
         class_name = self.__class__.__name__
@@ -585,9 +550,7 @@ class _AsyncLiftContextManager(typing.AsyncContextManager[_T]):
 
 
 def _wrap_gen_lifespan_context(
-    lifespan_context: typing.Callable[
-        [typing.Any], typing.Generator[typing.Any, typing.Any, typing.Any]
-    ],
+    lifespan_context: typing.Callable[[typing.Any], typing.Generator[typing.Any, typing.Any, typing.Any]],
 ) -> typing.Callable[[typing.Any], typing.AsyncContextManager[typing.Any]]:
     cmgr = contextlib.contextmanager(lifespan_context)
 
@@ -730,9 +693,7 @@ class Router:
             async with self.lifespan_context(app) as maybe_state:
                 if maybe_state is not None:
                     if "state" not in scope:
-                        raise RuntimeError(
-                            'The server does not support "state" in the lifespan scope.'
-                        )
+                        raise RuntimeError('The server does not support "state" in the lifespan scope.')
                     scope["state"].update(maybe_state)
                 await send({"type": "lifespan.startup.complete"})
                 started = True
@@ -806,15 +767,11 @@ class Router:
     def __eq__(self, other: typing.Any) -> bool:
         return isinstance(other, Router) and self.routes == other.routes
 
-    def mount(
-        self, path: str, app: ASGIApp, name: str | None = None
-    ) -> None:  # pragma: nocover
+    def mount(self, path: str, app: ASGIApp, name: str | None = None) -> None:  # pragma: nocover
         route = Mount(path, app=app, name=name)
         self.routes.append(route)
 
-    def host(
-        self, host: str, app: ASGIApp, name: str | None = None
-    ) -> None:  # pragma: no cover
+    def host(self, host: str, app: ASGIApp, name: str | None = None) -> None:  # pragma: no cover
         route = Host(host, app=app, name=name)
         self.routes.append(route)
 
@@ -860,11 +817,11 @@ class Router:
         """
         warnings.warn(
             "The `route` decorator is deprecated, and will be removed in version 1.0.0."
-            "Refer to https://www.starlette.io/routing/#http-routing for the recommended approach.",  # noqa: E501
+            "Refer to https://www.starlette.io/routing/#http-routing for the recommended approach.",
             DeprecationWarning,
         )
 
-        def decorator(func: typing.Callable) -> typing.Callable:  # type: ignore[type-arg]  # noqa: E501
+        def decorator(func: typing.Callable) -> typing.Callable:  # type: ignore[type-arg]
             self.add_route(
                 path,
                 func,
@@ -885,20 +842,18 @@ class Router:
         >>> app = Starlette(routes=routes)
         """
         warnings.warn(
-            "The `websocket_route` decorator is deprecated, and will be removed in version 1.0.0. Refer to "  # noqa: E501
-            "https://www.starlette.io/routing/#websocket-routing for the recommended approach.",  # noqa: E501
+            "The `websocket_route` decorator is deprecated, and will be removed in version 1.0.0. Refer to "
+            "https://www.starlette.io/routing/#websocket-routing for the recommended approach.",
             DeprecationWarning,
         )
 
-        def decorator(func: typing.Callable) -> typing.Callable:  # type: ignore[type-arg]  # noqa: E501
+        def decorator(func: typing.Callable) -> typing.Callable:  # type: ignore[type-arg]
             self.add_websocket_route(path, func, name=name)
             return func
 
         return decorator
 
-    def add_event_handler(
-        self, event_type: str, func: typing.Callable[[], typing.Any]
-    ) -> None:  # pragma: no cover
+    def add_event_handler(self, event_type: str, func: typing.Callable[[], typing.Any]) -> None:  # pragma: no cover
         assert event_type in ("startup", "shutdown")
 
         if event_type == "startup":
@@ -908,12 +863,12 @@ class Router:
 
     def on_event(self, event_type: str) -> typing.Callable:  # type: ignore[type-arg]
         warnings.warn(
-            "The `on_event` decorator is deprecated, and will be removed in version 1.0.0. "  # noqa: E501
+            "The `on_event` decorator is deprecated, and will be removed in version 1.0.0. "
             "Refer to https://www.starlette.io/lifespan/ for recommended approach.",
             DeprecationWarning,
         )
 
-        def decorator(func: typing.Callable) -> typing.Callable:  # type: ignore[type-arg]  # noqa: E501
+        def decorator(func: typing.Callable) -> typing.Callable:  # type: ignore[type-arg]
             self.add_event_handler(event_type, func)
             return func
 
index 89fa20b89a12a280cecac87c4eb89992398f23cf..688fd85bed2d25d56c0b98fa89f2d03cd05f46fe 100644 (file)
@@ -19,9 +19,7 @@ class OpenAPIResponse(Response):
 
     def render(self, content: typing.Any) -> bytes:
         assert yaml is not None, "`pyyaml` must be installed to use OpenAPIResponse."
-        assert isinstance(
-            content, dict
-        ), "The schema passed to OpenAPIResponse should be a dictionary."
+        assert isinstance(content, dict), "The schema passed to OpenAPIResponse should be a dictionary."
         return yaml.dump(content, default_flow_style=False).encode("utf-8")
 
 
@@ -73,9 +71,7 @@ class BaseSchemaGenerator:
                 for method in route.methods or ["GET"]:
                     if method == "HEAD":
                         continue
-                    endpoints_info.append(
-                        EndpointInfo(path, method.lower(), route.endpoint)
-                    )
+                    endpoints_info.append(EndpointInfo(path, method.lower(), route.endpoint))
             else:
                 path = self._remove_converter(route.path)
                 for method in ["get", "post", "put", "patch", "delete", "options"]:
@@ -95,9 +91,7 @@ class BaseSchemaGenerator:
         """
         return re.sub(r":\w+}", "}", path)
 
-    def parse_docstring(
-        self, func_or_method: typing.Callable[..., typing.Any]
-    ) -> dict[str, typing.Any]:
+    def parse_docstring(self, func_or_method: typing.Callable[..., typing.Any]) -> dict[str, typing.Any]:
         """
         Given a function, parse the docstring as YAML and return a dictionary of info.
         """
index afb09b56b24e4c1dbed401d04842e1adc052f06d..7498c30112134f655f7ae1c9a1079cb15552dbcf 100644 (file)
@@ -32,11 +32,7 @@ class NotModifiedResponse(Response):
     def __init__(self, headers: Headers):
         super().__init__(
             status_code=304,
-            headers={
-                name: value
-                for name, value in headers.items()
-                if name in self.NOT_MODIFIED_HEADERS
-            },
+            headers={name: value for name, value in headers.items() if name in self.NOT_MODIFIED_HEADERS},
         )
 
 
@@ -80,9 +76,7 @@ class StaticFiles:
             spec = importlib.util.find_spec(package)
             assert spec is not None, f"Package {package!r} could not be found."
             assert spec.origin is not None, f"Package {package!r} could not be found."
-            package_directory = os.path.normpath(
-                os.path.join(spec.origin, "..", statics_dir)
-            )
+            package_directory = os.path.normpath(os.path.join(spec.origin, "..", statics_dir))
             assert os.path.isdir(
                 package_directory
             ), f"Directory '{statics_dir!r}' in package {package!r} could not be found."
@@ -110,7 +104,7 @@ class StaticFiles:
         with OS specific path separators, and any '..', '.' components removed.
         """
         route_path = get_route_path(scope)
-        return os.path.normpath(os.path.join(*route_path.split("/")))  # noqa: E501
+        return os.path.normpath(os.path.join(*route_path.split("/")))
 
     async def get_response(self, path: str, scope: Scope) -> Response:
         """
@@ -120,9 +114,7 @@ class StaticFiles:
             raise HTTPException(status_code=405)
 
         try:
-            full_path, stat_result = await anyio.to_thread.run_sync(
-                self.lookup_path, path
-            )
+            full_path, stat_result = await anyio.to_thread.run_sync(self.lookup_path, path)
         except PermissionError:
             raise HTTPException(status_code=401)
         except OSError as exc:
@@ -140,9 +132,7 @@ class StaticFiles:
             # We're in HTML mode, and have got a directory URL.
             # Check if we have 'index.html' file to serve.
             index_path = os.path.join(path, "index.html")
-            full_path, stat_result = await anyio.to_thread.run_sync(
-                self.lookup_path, index_path
-            )
+            full_path, stat_result = await anyio.to_thread.run_sync(self.lookup_path, index_path)
             if stat_result is not None and stat.S_ISREG(stat_result.st_mode):
                 if not scope["path"].endswith("/"):
                     # Directory URLs should redirect to always end in "/".
@@ -153,9 +143,7 @@ class StaticFiles:
 
         if self.html:
             # Check for '404.html' if we're in HTML mode.
-            full_path, stat_result = await anyio.to_thread.run_sync(
-                self.lookup_path, "404.html"
-            )
+            full_path, stat_result = await anyio.to_thread.run_sync(self.lookup_path, "404.html")
             if stat_result and stat.S_ISREG(stat_result.st_mode):
                 return FileResponse(full_path, stat_result=stat_result, status_code=404)
         raise HTTPException(status_code=404)
@@ -187,9 +175,7 @@ class StaticFiles:
     ) -> Response:
         request_headers = Headers(scope=scope)
 
-        response = FileResponse(
-            full_path, status_code=status_code, stat_result=stat_result
-        )
+        response = FileResponse(full_path, status_code=status_code, stat_result=stat_result)
         if self.is_not_modified(response.headers, request_headers):
             return NotModifiedResponse(response.headers)
         return response
@@ -206,17 +192,11 @@ class StaticFiles:
         try:
             stat_result = await anyio.to_thread.run_sync(os.stat, self.directory)
         except FileNotFoundError:
-            raise RuntimeError(
-                f"StaticFiles directory '{self.directory}' does not exist."
-            )
+            raise RuntimeError(f"StaticFiles directory '{self.directory}' does not exist.")
         if not (stat.S_ISDIR(stat_result.st_mode) or stat.S_ISLNK(stat_result.st_mode)):
-            raise RuntimeError(
-                f"StaticFiles path '{self.directory}' is not a directory."
-            )
+            raise RuntimeError(f"StaticFiles path '{self.directory}' is not a directory.")
 
-    def is_not_modified(
-        self, response_headers: Headers, request_headers: Headers
-    ) -> bool:
+    def is_not_modified(self, response_headers: Headers, request_headers: Headers) -> bool:
         """
         Given the request and response headers, return `True` if an HTTP
         "Not Modified" response could be returned instead.
@@ -232,11 +212,7 @@ class StaticFiles:
         try:
             if_modified_since = parsedate(request_headers["if-modified-since"])
             last_modified = parsedate(response_headers["last-modified"])
-            if (
-                if_modified_since is not None
-                and last_modified is not None
-                and if_modified_since >= last_modified
-            ):
+            if if_modified_since is not None and last_modified is not None and if_modified_since >= last_modified:
                 return True
         except KeyError:
             pass
index aae2cbe24684f2fe9fe52853b31c16c5f3add3b2..48e54c0cc782a3c172f187dcd3a263be636246c5 100644 (file)
@@ -68,8 +68,7 @@ class Jinja2Templates:
         self,
         directory: str | PathLike[str] | typing.Sequence[str | PathLike[str]],
         *,
-        context_processors: list[typing.Callable[[Request], dict[str, typing.Any]]]
-        | None = None,
+        context_processors: list[typing.Callable[[Request], dict[str, typing.Any]]] | None = None,
         **env_options: typing.Any,
     ) -> None: ...
 
@@ -78,31 +77,24 @@ class Jinja2Templates:
         self,
         *,
         env: jinja2.Environment,
-        context_processors: list[typing.Callable[[Request], dict[str, typing.Any]]]
-        | None = None,
+        context_processors: list[typing.Callable[[Request], dict[str, typing.Any]]] | None = None,
     ) -> None: ...
 
     def __init__(
         self,
-        directory: str
-        | PathLike[str]
-        | typing.Sequence[str | PathLike[str]]
-        | None = None,
+        directory: str | PathLike[str] | typing.Sequence[str | PathLike[str]] | None = None,
         *,
-        context_processors: list[typing.Callable[[Request], dict[str, typing.Any]]]
-        | None = None,
+        context_processors: list[typing.Callable[[Request], dict[str, typing.Any]]] | None = None,
         env: jinja2.Environment | None = None,
         **env_options: typing.Any,
     ) -> None:
         if env_options:
             warnings.warn(
-                "Extra environment options are deprecated. Use a preconfigured jinja2.Environment instead.",  # noqa: E501
+                "Extra environment options are deprecated. Use a preconfigured jinja2.Environment instead.",
                 DeprecationWarning,
             )
         assert jinja2 is not None, "jinja2 must be installed to use Jinja2Templates"
-        assert bool(directory) ^ bool(
-            env
-        ), "either 'directory' or 'env' arguments must be passed"
+        assert bool(directory) ^ bool(env), "either 'directory' or 'env' arguments must be passed"
         self.context_processors = context_processors or []
         if directory is not None:
             self.env = self._create_env(directory, **env_options)
@@ -163,25 +155,19 @@ class Jinja2Templates:
         # Deprecated usage
         ...
 
-    def TemplateResponse(
-        self, *args: typing.Any, **kwargs: typing.Any
-    ) -> _TemplateResponse:
+    def TemplateResponse(self, *args: typing.Any, **kwargs: typing.Any) -> _TemplateResponse:
         if args:
-            if isinstance(
-                args[0], str
-            ):  # the first argument is template name (old style)
+            if isinstance(args[0], str):  # the first argument is template name (old style)
                 warnings.warn(
                     "The `name` is not the first parameter anymore. "
                     "The first parameter should be the `Request` instance.\n"
-                    'Replace `TemplateResponse(name, {"request": request})` by `TemplateResponse(request, name)`.',  # noqa: E501
+                    'Replace `TemplateResponse(name, {"request": request})` by `TemplateResponse(request, name)`.',
                     DeprecationWarning,
                 )
 
                 name = args[0]
                 context = args[1] if len(args) > 1 else kwargs.get("context", {})
-                status_code = (
-                    args[2] if len(args) > 2 else kwargs.get("status_code", 200)
-                )
+                status_code = args[2] if len(args) > 2 else kwargs.get("status_code", 200)
                 headers = args[2] if len(args) > 2 else kwargs.get("headers")
                 media_type = args[3] if len(args) > 3 else kwargs.get("media_type")
                 background = args[4] if len(args) > 4 else kwargs.get("background")
@@ -193,9 +179,7 @@ class Jinja2Templates:
                 request = args[0]
                 name = args[1] if len(args) > 1 else kwargs["name"]
                 context = args[2] if len(args) > 2 else kwargs.get("context", {})
-                status_code = (
-                    args[3] if len(args) > 3 else kwargs.get("status_code", 200)
-                )
+                status_code = args[3] if len(args) > 3 else kwargs.get("status_code", 200)
                 headers = args[4] if len(args) > 4 else kwargs.get("headers")
                 media_type = args[5] if len(args) > 5 else kwargs.get("media_type")
                 background = args[6] if len(args) > 6 else kwargs.get("background")
@@ -203,7 +187,7 @@ class Jinja2Templates:
             if "request" not in kwargs:
                 warnings.warn(
                     "The `TemplateResponse` now requires the `request` argument.\n"
-                    'Replace `TemplateResponse(name, {"context": context})` by `TemplateResponse(request, name)`.',  # noqa: E501
+                    'Replace `TemplateResponse(name, {"context": context})` by `TemplateResponse(request, name)`.',
                     DeprecationWarning,
                 )
                 if "request" not in kwargs.get("context", {}):
index bf928d23f1f720f86a98a5cba85b211d81854e06..fcf392e334eaca069f16710267f72a567ed421f1 100644 (file)
@@ -37,9 +37,7 @@ except ModuleNotFoundError:  # pragma: no cover
         "You can install this with:\n"
         "    $ pip install httpx\n"
     )
-_PortalFactoryType = typing.Callable[
-    [], typing.ContextManager[anyio.abc.BlockingPortal]
-]
+_PortalFactoryType = typing.Callable[[], typing.ContextManager[anyio.abc.BlockingPortal]]
 
 ASGIInstance = typing.Callable[[Receive, Send], typing.Awaitable[None]]
 ASGI2App = typing.Callable[[Scope], ASGIInstance]
@@ -169,9 +167,7 @@ class WebSocketTestSession:
 
     def _raise_on_close(self, message: Message) -> None:
         if message["type"] == "websocket.close":
-            raise WebSocketDisconnect(
-                code=message.get("code", 1000), reason=message.get("reason", "")
-            )
+            raise WebSocketDisconnect(code=message.get("code", 1000), reason=message.get("reason", ""))
         elif message["type"] == "websocket.http.response.start":
             status_code: int = message["status"]
             headers: list[tuple[bytes, bytes]] = message["headers"]
@@ -199,9 +195,7 @@ class WebSocketTestSession:
     def send_bytes(self, data: bytes) -> None:
         self.send({"type": "websocket.receive", "bytes": data})
 
-    def send_json(
-        self, data: typing.Any, mode: typing.Literal["text", "binary"] = "text"
-    ) -> None:
+    def send_json(self, data: typing.Any, mode: typing.Literal["text", "binary"] = "text") -> None:
         text = json.dumps(data, separators=(",", ":"), ensure_ascii=False)
         if mode == "text":
             self.send({"type": "websocket.receive", "text": text})
@@ -227,9 +221,7 @@ class WebSocketTestSession:
         self._raise_on_close(message)
         return typing.cast(bytes, message["bytes"])
 
-    def receive_json(
-        self, mode: typing.Literal["text", "binary"] = "text"
-    ) -> typing.Any:
+    def receive_json(self, mode: typing.Literal["text", "binary"] = "text") -> typing.Any:
         message = self.receive()
         self._raise_on_close(message)
         if mode == "text":
@@ -280,10 +272,7 @@ class _TestClientTransport(httpx.BaseTransport):
             headers = [(b"host", (f"{host}:{port}").encode())]
 
         # Include other request headers.
-        headers += [
-            (key.lower().encode(), value.encode())
-            for key, value in request.headers.multi_items()
-        ]
+        headers += [(key.lower().encode(), value.encode()) for key, value in request.headers.multi_items()]
 
         scope: dict[str, typing.Any]
 
@@ -365,22 +354,13 @@ class _TestClientTransport(httpx.BaseTransport):
             nonlocal raw_kwargs, response_started, template, context
 
             if message["type"] == "http.response.start":
-                assert (
-                    not response_started
-                ), 'Received multiple "http.response.start" messages.'
+                assert not response_started, 'Received multiple "http.response.start" messages.'
                 raw_kwargs["status_code"] = message["status"]
-                raw_kwargs["headers"] = [
-                    (key.decode(), value.decode())
-                    for key, value in message.get("headers", [])
-                ]
+                raw_kwargs["headers"] = [(key.decode(), value.decode()) for key, value in message.get("headers", [])]
                 response_started = True
             elif message["type"] == "http.response.body":
-                assert (
-                    response_started
-                ), 'Received "http.response.body" without "http.response.start".'
-                assert (
-                    not response_complete.is_set()
-                ), 'Received "http.response.body" after response completed.'
+                assert response_started, 'Received "http.response.body" without "http.response.start".'
+                assert not response_complete.is_set(), 'Received "http.response.body" after response completed.'
                 body = message.get("body", b"")
                 more_body = message.get("more_body", False)
                 if request.method != "HEAD":
@@ -435,9 +415,7 @@ class TestClient(httpx.Client):
         headers: dict[str, str] | None = None,
         follow_redirects: bool = True,
     ) -> None:
-        self.async_backend = _AsyncBackend(
-            backend=backend, backend_options=backend_options or {}
-        )
+        self.async_backend = _AsyncBackend(backend=backend, backend_options=backend_options or {})
         if _is_asgi3(app):
             asgi_app = app
         else:
@@ -468,22 +446,15 @@ class TestClient(httpx.Client):
         if self.portal is not None:
             yield self.portal
         else:
-            with anyio.from_thread.start_blocking_portal(
-                **self.async_backend
-            ) as portal:
+            with anyio.from_thread.start_blocking_portal(**self.async_backend) as portal:
                 yield portal
 
     def _choose_redirect_arg(
         self, follow_redirects: bool | None, allow_redirects: bool | None
     ) -> bool | httpx._client.UseClientDefault:
-        redirect: bool | httpx._client.UseClientDefault = (
-            httpx._client.USE_CLIENT_DEFAULT
-        )
+        redirect: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT
         if allow_redirects is not None:
-            message = (
-                "The `allow_redirects` argument is deprecated. "
-                "Use `follow_redirects` instead."
-            )
+            message = "The `allow_redirects` argument is deprecated. Use `follow_redirects` instead."
             warnings.warn(message, DeprecationWarning)
             redirect = allow_redirects
         if follow_redirects is not None:
@@ -506,12 +477,10 @@ class TestClient(httpx.Client):
         params: httpx._types.QueryParamTypes | None = None,
         headers: httpx._types.HeaderTypes | None = None,
         cookies: httpx._types.CookieTypes | None = None,
-        auth: httpx._types.AuthTypes
-        | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
+        auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
         follow_redirects: bool | None = None,
         allow_redirects: bool | None = None,
-        timeout: httpx._types.TimeoutTypes
-        | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
+        timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
         extensions: dict[str, typing.Any] | None = None,
     ) -> httpx.Response:
         url = self._merge_url(url)
@@ -539,12 +508,10 @@ class TestClient(httpx.Client):
         params: httpx._types.QueryParamTypes | None = None,
         headers: httpx._types.HeaderTypes | None = None,
         cookies: httpx._types.CookieTypes | None = None,
-        auth: httpx._types.AuthTypes
-        | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
+        auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
         follow_redirects: bool | None = None,
         allow_redirects: bool | None = None,
-        timeout: httpx._types.TimeoutTypes
-        | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
+        timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
         extensions: dict[str, typing.Any] | None = None,
     ) -> httpx.Response:
         redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
@@ -566,12 +533,10 @@ class TestClient(httpx.Client):
         params: httpx._types.QueryParamTypes | None = None,
         headers: httpx._types.HeaderTypes | None = None,
         cookies: httpx._types.CookieTypes | None = None,
-        auth: httpx._types.AuthTypes
-        | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
+        auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
         follow_redirects: bool | None = None,
         allow_redirects: bool | None = None,
-        timeout: httpx._types.TimeoutTypes
-        | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
+        timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
         extensions: dict[str, typing.Any] | None = None,
     ) -> httpx.Response:
         redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
@@ -593,12 +558,10 @@ class TestClient(httpx.Client):
         params: httpx._types.QueryParamTypes | None = None,
         headers: httpx._types.HeaderTypes | None = None,
         cookies: httpx._types.CookieTypes | None = None,
-        auth: httpx._types.AuthTypes
-        | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
+        auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
         follow_redirects: bool | None = None,
         allow_redirects: bool | None = None,
-        timeout: httpx._types.TimeoutTypes
-        | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
+        timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
         extensions: dict[str, typing.Any] | None = None,
     ) -> httpx.Response:
         redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
@@ -624,12 +587,10 @@ class TestClient(httpx.Client):
         params: httpx._types.QueryParamTypes | None = None,
         headers: httpx._types.HeaderTypes | None = None,
         cookies: httpx._types.CookieTypes | None = None,
-        auth: httpx._types.AuthTypes
-        | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
+        auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
         follow_redirects: bool | None = None,
         allow_redirects: bool | None = None,
-        timeout: httpx._types.TimeoutTypes
-        | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
+        timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
         extensions: dict[str, typing.Any] | None = None,
     ) -> httpx.Response:
         redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
@@ -659,12 +620,10 @@ class TestClient(httpx.Client):
         params: httpx._types.QueryParamTypes | None = None,
         headers: httpx._types.HeaderTypes | None = None,
         cookies: httpx._types.CookieTypes | None = None,
-        auth: httpx._types.AuthTypes
-        | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
+        auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
         follow_redirects: bool | None = None,
         allow_redirects: bool | None = None,
-        timeout: httpx._types.TimeoutTypes
-        | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
+        timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
         extensions: dict[str, typing.Any] | None = None,
     ) -> httpx.Response:
         redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
@@ -694,12 +653,10 @@ class TestClient(httpx.Client):
         params: httpx._types.QueryParamTypes | None = None,
         headers: httpx._types.HeaderTypes | None = None,
         cookies: httpx._types.CookieTypes | None = None,
-        auth: httpx._types.AuthTypes
-        | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
+        auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
         follow_redirects: bool | None = None,
         allow_redirects: bool | None = None,
-        timeout: httpx._types.TimeoutTypes
-        | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
+        timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
         extensions: dict[str, typing.Any] | None = None,
     ) -> httpx.Response:
         redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
@@ -725,12 +682,10 @@ class TestClient(httpx.Client):
         params: httpx._types.QueryParamTypes | None = None,
         headers: httpx._types.HeaderTypes | None = None,
         cookies: httpx._types.CookieTypes | None = None,
-        auth: httpx._types.AuthTypes
-        | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
+        auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
         follow_redirects: bool | None = None,
         allow_redirects: bool | None = None,
-        timeout: httpx._types.TimeoutTypes
-        | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
+        timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
         extensions: dict[str, typing.Any] | None = None,
     ) -> httpx.Response:
         redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
@@ -770,9 +725,7 @@ class TestClient(httpx.Client):
 
     def __enter__(self) -> TestClient:
         with contextlib.ExitStack() as stack:
-            self.portal = portal = stack.enter_context(
-                anyio.from_thread.start_blocking_portal(**self.async_backend)
-            )
+            self.portal = portal = stack.enter_context(anyio.from_thread.start_blocking_portal(**self.async_backend))
 
             @stack.callback
             def reset_portal() -> None:
index f78dd63ae08c0a77c902d00ee2add03ebecd7909..893f872964c2e6df9a81fdba3dd8fadfeaab9731 100644 (file)
@@ -16,15 +16,9 @@ Send = typing.Callable[[Message], typing.Awaitable[None]]
 ASGIApp = typing.Callable[[Scope, Receive, Send], typing.Awaitable[None]]
 
 StatelessLifespan = typing.Callable[[AppType], typing.AsyncContextManager[None]]
-StatefulLifespan = typing.Callable[
-    [AppType], typing.AsyncContextManager[typing.Mapping[str, typing.Any]]
-]
+StatefulLifespan = typing.Callable[[AppType], typing.AsyncContextManager[typing.Mapping[str, typing.Any]]]
 Lifespan = typing.Union[StatelessLifespan[AppType], StatefulLifespan[AppType]]
 
-HTTPExceptionHandler = typing.Callable[
-    ["Request", Exception], "Response | typing.Awaitable[Response]"
-]
-WebSocketExceptionHandler = typing.Callable[
-    ["WebSocket", Exception], typing.Awaitable[None]
-]
+HTTPExceptionHandler = typing.Callable[["Request", Exception], "Response | typing.Awaitable[Response]"]
+WebSocketExceptionHandler = typing.Callable[["WebSocket", Exception], typing.Awaitable[None]]
 ExceptionHandler = typing.Union[HTTPExceptionHandler, WebSocketExceptionHandler]
index 53ab5a70c8f1111cb06c56e927bfc2a1b21003b5..b7acaa3f0aa8e717c04c8fb21dd2d0a476f28d1f 100644 (file)
@@ -39,10 +39,7 @@ class WebSocket(HTTPConnection):
             message = await self._receive()
             message_type = message["type"]
             if message_type != "websocket.connect":
-                raise RuntimeError(
-                    'Expected ASGI message "websocket.connect", '
-                    f"but got {message_type!r}"
-                )
+                raise RuntimeError(f'Expected ASGI message "websocket.connect", but got {message_type!r}')
             self.client_state = WebSocketState.CONNECTED
             return message
         elif self.client_state == WebSocketState.CONNECTED:
@@ -50,16 +47,13 @@ class WebSocket(HTTPConnection):
             message_type = message["type"]
             if message_type not in {"websocket.receive", "websocket.disconnect"}:
                 raise RuntimeError(
-                    'Expected ASGI message "websocket.receive" or '
-                    f'"websocket.disconnect", but got {message_type!r}'
+                    f'Expected ASGI message "websocket.receive" or "websocket.disconnect", but got {message_type!r}'
                 )
             if message_type == "websocket.disconnect":
                 self.client_state = WebSocketState.DISCONNECTED
             return message
         else:
-            raise RuntimeError(
-                'Cannot call "receive" once a disconnect message has been received.'
-            )
+            raise RuntimeError('Cannot call "receive" once a disconnect message has been received.')
 
     async def send(self, message: Message) -> None:
         """
@@ -67,14 +61,9 @@ class WebSocket(HTTPConnection):
         """
         if self.application_state == WebSocketState.CONNECTING:
             message_type = message["type"]
-            if message_type not in {
-                "websocket.accept",
-                "websocket.close",
-                "websocket.http.response.start",
-            }:
+            if message_type not in {"websocket.accept", "websocket.close", "websocket.http.response.start"}:
                 raise RuntimeError(
-                    'Expected ASGI message "websocket.accept",'
-                    '"websocket.close" or "websocket.http.response.start",'
+                    'Expected ASGI message "websocket.accept", "websocket.close" or "websocket.http.response.start", '
                     f"but got {message_type!r}"
                 )
             if message_type == "websocket.close":
@@ -88,8 +77,7 @@ class WebSocket(HTTPConnection):
             message_type = message["type"]
             if message_type not in {"websocket.send", "websocket.close"}:
                 raise RuntimeError(
-                    'Expected ASGI message "websocket.send" or "websocket.close", '
-                    f"but got {message_type!r}"
+                    f'Expected ASGI message "websocket.send" or "websocket.close", but got {message_type!r}'
                 )
             if message_type == "websocket.close":
                 self.application_state = WebSocketState.DISCONNECTED
@@ -101,10 +89,7 @@ class WebSocket(HTTPConnection):
         elif self.application_state == WebSocketState.RESPONSE:
             message_type = message["type"]
             if message_type != "websocket.http.response.body":
-                raise RuntimeError(
-                    'Expected ASGI message "websocket.http.response.body", '
-                    f"but got {message_type!r}"
-                )
+                raise RuntimeError(f'Expected ASGI message "websocket.http.response.body", but got {message_type!r}')
             if not message.get("more_body", False):
                 self.application_state = WebSocketState.DISCONNECTED
             await self._send(message)
@@ -121,9 +106,7 @@ class WebSocket(HTTPConnection):
         if self.client_state == WebSocketState.CONNECTING:
             # If we haven't yet seen the 'connect' message, then wait for it first.
             await self.receive()
-        await self.send(
-            {"type": "websocket.accept", "subprotocol": subprotocol, "headers": headers}
-        )
+        await self.send({"type": "websocket.accept", "subprotocol": subprotocol, "headers": headers})
 
     def _raise_on_disconnect(self, message: Message) -> None:
         if message["type"] == "websocket.disconnect":
@@ -131,18 +114,14 @@ class WebSocket(HTTPConnection):
 
     async def receive_text(self) -> str:
         if self.application_state != WebSocketState.CONNECTED:
-            raise RuntimeError(
-                'WebSocket is not connected. Need to call "accept" first.'
-            )
+            raise RuntimeError('WebSocket is not connected. Need to call "accept" first.')
         message = await self.receive()
         self._raise_on_disconnect(message)
         return typing.cast(str, message["text"])
 
     async def receive_bytes(self) -> bytes:
         if self.application_state != WebSocketState.CONNECTED:
-            raise RuntimeError(
-                'WebSocket is not connected. Need to call "accept" first.'
-            )
+            raise RuntimeError('WebSocket is not connected. Need to call "accept" first.')
         message = await self.receive()
         self._raise_on_disconnect(message)
         return typing.cast(bytes, message["bytes"])
@@ -151,9 +130,7 @@ class WebSocket(HTTPConnection):
         if mode not in {"text", "binary"}:
             raise RuntimeError('The "mode" argument should be "text" or "binary".')
         if self.application_state != WebSocketState.CONNECTED:
-            raise RuntimeError(
-                'WebSocket is not connected. Need to call "accept" first.'
-            )
+            raise RuntimeError('WebSocket is not connected. Need to call "accept" first.')
         message = await self.receive()
         self._raise_on_disconnect(message)
 
@@ -200,17 +177,13 @@ class WebSocket(HTTPConnection):
             await self.send({"type": "websocket.send", "bytes": text.encode("utf-8")})
 
     async def close(self, code: int = 1000, reason: str | None = None) -> None:
-        await self.send(
-            {"type": "websocket.close", "code": code, "reason": reason or ""}
-        )
+        await self.send({"type": "websocket.close", "code": code, "reason": reason or ""})
 
     async def send_denial_response(self, response: Response) -> None:
         if "websocket.http.response" in self.scope.get("extensions", {}):
             await response(self.scope, self.receive, self.send)
         else:
-            raise RuntimeError(
-                "The server doesn't support the Websocket Denial Response extension."
-            )
+            raise RuntimeError("The server doesn't support the Websocket Denial Response extension.")
 
 
 class WebSocketClose:
@@ -219,6 +192,4 @@ class WebSocketClose:
         self.reason = reason or ""
 
     async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
-        await send(
-            {"type": "websocket.close", "code": self.code, "reason": self.reason}
-        )
+        await send({"type": "websocket.close", "code": self.code, "reason": self.reason})
index 8e410cb1515da7e21c292a259b6f3bf8966058cf..22503865088bafcb62776500f566fa1571fc0d78 100644 (file)
@@ -169,9 +169,7 @@ def test_app_middleware_argument(test_client_factory: TestClientFactory) -> None
     def homepage(request: Request) -> PlainTextResponse:
         return PlainTextResponse("Homepage")
 
-    app = Starlette(
-        routes=[Route("/", homepage)], middleware=[Middleware(CustomMiddleware)]
-    )
+    app = Starlette(routes=[Route("/", homepage)], middleware=[Middleware(CustomMiddleware)])
 
     client = test_client_factory(app)
     response = client.get("/")
@@ -249,9 +247,7 @@ def test_contextvars(
         ctxvar.set("set by endpoint")
         return PlainTextResponse("Homepage")
 
-    app = Starlette(
-        middleware=[Middleware(middleware_cls)], routes=[Route("/", homepage)]
-    )
+    app = Starlette(middleware=[Middleware(middleware_cls)], routes=[Route("/", homepage)])
 
     client = test_client_factory(app)
     response = client.get("/")
@@ -316,13 +312,9 @@ async def test_do_not_block_on_background_tasks() -> None:
         events.append("Background task finished")
 
     async def endpoint_with_background_task(_: Request) -> PlainTextResponse:
-        return PlainTextResponse(
-            content="Hello", background=BackgroundTask(sleep_and_set)
-        )
+        return PlainTextResponse(content="Hello", background=BackgroundTask(sleep_and_set))
 
-    async def passthrough(
-        request: Request, call_next: RequestResponseEndpoint
-    ) -> Response:
+    async def passthrough(request: Request, call_next: RequestResponseEndpoint) -> Response:
         return await call_next(request)
 
     app = Starlette(
@@ -490,9 +482,7 @@ def test_app_receives_http_disconnect_while_sending_if_discarded(
                         }
                     )
 
-            pytest.fail(
-                "http.disconnect should have been received and canceled the scope"
-            )  # pragma: no cover
+            pytest.fail("http.disconnect should have been received and canceled the scope")  # pragma: no cover
 
     app = DiscardingMiddleware(downstream_app)
 
@@ -787,7 +777,7 @@ async def test_read_request_stream_in_dispatch_wrapping_app_calls_body() -> None
     await rcv_stream.aclose()
 
 
-def test_read_request_stream_in_dispatch_after_app_calls_body_with_middleware_calling_body_before_call_next(  # noqa: E501
+def test_read_request_stream_in_dispatch_after_app_calls_body_with_middleware_calling_body_before_call_next(
     test_client_factory: TestClientFactory,
 ) -> None:
     async def homepage(request: Request) -> PlainTextResponse:
@@ -800,9 +790,7 @@ def test_read_request_stream_in_dispatch_after_app_calls_body_with_middleware_ca
             request: Request,
             call_next: RequestResponseEndpoint,
         ) -> Response:
-            assert (
-                await request.body() == b"a"
-            )  # this buffers the request body in memory
+            assert await request.body() == b"a"  # this buffers the request body in memory
             resp = await call_next(request)
             async for chunk in request.stream():
                 if chunk:
@@ -819,7 +807,7 @@ def test_read_request_stream_in_dispatch_after_app_calls_body_with_middleware_ca
     assert response.status_code == 200
 
 
-def test_read_request_body_in_dispatch_after_app_calls_body_with_middleware_calling_body_before_call_next(  # noqa: E501
+def test_read_request_body_in_dispatch_after_app_calls_body_with_middleware_calling_body_before_call_next(
     test_client_factory: TestClientFactory,
 ) -> None:
     async def homepage(request: Request) -> PlainTextResponse:
@@ -832,9 +820,7 @@ def test_read_request_body_in_dispatch_after_app_calls_body_with_middleware_call
             request: Request,
             call_next: RequestResponseEndpoint,
         ) -> Response:
-            assert (
-                await request.body() == b"a"
-            )  # this buffers the request body in memory
+            assert await request.body() == b"a"  # this buffers the request body in memory
             resp = await call_next(request)
             assert await request.body() == b"a"  # no problem here
             return resp
@@ -1026,9 +1012,7 @@ async def test_multiple_middlewares_stacked_client_disconnected() -> None:
             self.events = events
             super().__init__(app)
 
-        async def dispatch(
-            self, request: Request, call_next: RequestResponseEndpoint
-        ) -> Response:
+        async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
             self.events.append(f"{self.version}:STARTED")
             res = await call_next(request)
             self.events.append(f"{self.version}:COMPLETED")
@@ -1047,9 +1031,7 @@ async def test_multiple_middlewares_stacked_client_disconnected() -> None:
 
     app = Starlette(
         routes=[Route("/", sleepy)],
-        middleware=[
-            Middleware(MyMiddleware, version=_ + 1, events=events) for _ in range(10)
-        ],
+        middleware=[Middleware(MyMiddleware, version=_ + 1, events=events) for _ in range(10)],
     )
 
     scope = {
@@ -1114,9 +1096,7 @@ async def test_poll_for_disconnect_repeated(send_body: bool) -> None:
         await Response(b"good!")(scope, receive, send)
 
     class MyMiddleware(BaseHTTPMiddleware):
-        async def dispatch(
-            self, request: Request, call_next: RequestResponseEndpoint
-        ) -> Response:
+        async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
             return await call_next(request)
 
     app = MyMiddleware(app_poll_disconnect)
index 6303612434361006e505cd016be7344942195316..0d987263e7f857da9f1e8803bf1bfcc800d82f45 100644 (file)
@@ -252,9 +252,7 @@ def test_cors_preflight_allow_all_methods(
 
     app = Starlette(
         routes=[Route("/", endpoint=homepage)],
-        middleware=[
-            Middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"])
-        ],
+        middleware=[Middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"])],
     )
 
     client = test_client_factory(app)
@@ -284,9 +282,7 @@ def test_cors_allow_all_methods(
                 methods=["delete", "get", "head", "options", "patch", "post", "put"],
             )
         ],
-        middleware=[
-            Middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"])
-        ],
+        middleware=[Middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"])],
     )
 
     client = test_client_factory(app)
@@ -397,10 +393,7 @@ def test_cors_allow_origin_regex_fullmatch(
     response = client.get("/", headers=headers)
     assert response.status_code == 200
     assert response.text == "Homepage"
-    assert (
-        response.headers["access-control-allow-origin"]
-        == "https://subdomain.example.org"
-    )
+    assert response.headers["access-control-allow-origin"] == "https://subdomain.example.org"
     assert "access-control-allow-credentials" not in response.headers
 
     # Test diallowed standard response
@@ -456,9 +449,7 @@ def test_cors_vary_header_is_not_set_for_non_credentialed_request(
     test_client_factory: TestClientFactory,
 ) -> None:
     def homepage(request: Request) -> PlainTextResponse:
-        return PlainTextResponse(
-            "Homepage", status_code=200, headers={"Vary": "Accept-Encoding"}
-        )
+        return PlainTextResponse("Homepage", status_code=200, headers={"Vary": "Accept-Encoding"})
 
     app = Starlette(
         routes=[Route("/", endpoint=homepage)],
@@ -475,9 +466,7 @@ def test_cors_vary_header_is_properly_set_for_credentialed_request(
     test_client_factory: TestClientFactory,
 ) -> None:
     def homepage(request: Request) -> PlainTextResponse:
-        return PlainTextResponse(
-            "Homepage", status_code=200, headers={"Vary": "Accept-Encoding"}
-        )
+        return PlainTextResponse("Homepage", status_code=200, headers={"Vary": "Accept-Encoding"})
 
     app = Starlette(
         routes=[Route("/", endpoint=homepage)],
@@ -485,9 +474,7 @@ def test_cors_vary_header_is_properly_set_for_credentialed_request(
     )
     client = test_client_factory(app)
 
-    response = client.get(
-        "/", headers={"Cookie": "foo=bar", "Origin": "https://someplace.org"}
-    )
+    response = client.get("/", headers={"Cookie": "foo=bar", "Origin": "https://someplace.org"})
     assert response.status_code == 200
     assert response.headers["vary"] == "Accept-Encoding, Origin"
 
@@ -496,9 +483,7 @@ def test_cors_vary_header_is_properly_set_when_allow_origins_is_not_wildcard(
     test_client_factory: TestClientFactory,
 ) -> None:
     def homepage(request: Request) -> PlainTextResponse:
-        return PlainTextResponse(
-            "Homepage", status_code=200, headers={"Vary": "Accept-Encoding"}
-        )
+        return PlainTextResponse("Homepage", status_code=200, headers={"Vary": "Accept-Encoding"})
 
     app = Starlette(
         routes=[
@@ -538,9 +523,7 @@ def test_cors_allowed_origin_does_not_leak_between_credentialed_requests(
     assert response.headers["access-control-allow-origin"] == "*"
     assert "access-control-allow-credentials" not in response.headers
 
-    response = client.get(
-        "/", headers={"Cookie": "foo=bar", "Origin": "https://someplace.org"}
-    )
+    response = client.get("/", headers={"Cookie": "foo=bar", "Origin": "https://someplace.org"})
     assert response.headers["access-control-allow-origin"] == "https://someplace.org"
     assert "access-control-allow-credentials" not in response.headers
 
index b6f68296dc6da63cbbfad493d172cd4c9a6a1059..b20a7cb84b57478ac090ca5c8bdd2d20161f5226 100644 (file)
@@ -91,9 +91,7 @@ def test_gzip_ignored_for_responses_with_encoding_set(
                 yield bytes
 
         streaming = generator(bytes=b"x" * 400, count=10)
-        return StreamingResponse(
-            streaming, status_code=200, headers={"Content-Encoding": "text"}
-        )
+        return StreamingResponse(streaming, status_code=200, headers={"Content-Encoding": "text"})
 
     app = Starlette(
         routes=[Route("/", endpoint=homepage)],
index 9a0d70a0d7c93f3849ad7fa6c5bcadcdc9b4ca78..b4f3c64fac84d436210154a05075bcce9d7275cf 100644 (file)
@@ -89,9 +89,7 @@ def test_secure_session(test_client_factory: TestClientFactory) -> None:
             Route("/update_session", endpoint=update_session, methods=["POST"]),
             Route("/clear_session", endpoint=clear_session, methods=["POST"]),
         ],
-        middleware=[
-            Middleware(SessionMiddleware, secret_key="example", https_only=True)
-        ],
+        middleware=[Middleware(SessionMiddleware, secret_key="example", https_only=True)],
     )
     secure_client = test_client_factory(app, base_url="https://testserver")
     unsecure_client = test_client_factory(app, base_url="http://testserver")
@@ -126,9 +124,7 @@ def test_session_cookie_subpath(test_client_factory: TestClientFactory) -> None:
         routes=[
             Route("/update_session", endpoint=update_session, methods=["POST"]),
         ],
-        middleware=[
-            Middleware(SessionMiddleware, secret_key="example", path="/second_app")
-        ],
+        middleware=[Middleware(SessionMiddleware, secret_key="example", path="/second_app")],
     )
     app = Starlette(routes=[Mount("/second_app", app=second_app)])
     client = test_client_factory(app, base_url="http://testserver/second_app")
@@ -188,9 +184,7 @@ def test_domain_cookie(test_client_factory: TestClientFactory) -> None:
             Route("/view_session", endpoint=view_session),
             Route("/update_session", endpoint=update_session, methods=["POST"]),
         ],
-        middleware=[
-            Middleware(SessionMiddleware, secret_key="example", domain=".example.com")
-        ],
+        middleware=[Middleware(SessionMiddleware, secret_key="example", domain=".example.com")],
     )
     client: TestClient = test_client_factory(app)
 
index ddff46c48ca6bf63302f6d77a47a221eecb092e7..5b8b217c3ec9fcd47236b179a7f6ba48c5ee52e8 100644 (file)
@@ -13,11 +13,7 @@ def test_trusted_host_middleware(test_client_factory: TestClientFactory) -> None
 
     app = Starlette(
         routes=[Route("/", endpoint=homepage)],
-        middleware=[
-            Middleware(
-                TrustedHostMiddleware, allowed_hosts=["testserver", "*.testserver"]
-            )
-        ],
+        middleware=[Middleware(TrustedHostMiddleware, allowed_hosts=["testserver", "*.testserver"])],
     )
 
     client = test_client_factory(app)
@@ -45,9 +41,7 @@ def test_www_redirect(test_client_factory: TestClientFactory) -> None:
 
     app = Starlette(
         routes=[Route("/", endpoint=homepage)],
-        middleware=[
-            Middleware(TrustedHostMiddleware, allowed_hosts=["www.example.com"])
-        ],
+        middleware=[Middleware(TrustedHostMiddleware, allowed_hosts=["www.example.com"])],
     )
 
     client = test_client_factory(app, base_url="https://example.com")
index 20da7ea8111b6031fcb595a5404132e025053cb0..86c713c38a64be87a8841b029cf9d0d7c83b66b2 100644 (file)
@@ -109,9 +109,7 @@ exception_handlers = {
     CustomWSException: custom_ws_exception_handler,
 }
 
-middleware = [
-    Middleware(TrustedHostMiddleware, allowed_hosts=["testserver", "*.example.org"])
-]
+middleware = [Middleware(TrustedHostMiddleware, allowed_hosts=["testserver", "*.example.org"])]
 
 app = Starlette(
     routes=[
@@ -349,9 +347,7 @@ def test_app_add_event_handler(test_client_factory: TestClientFactory) -> None:
         nonlocal cleanup_complete
         cleanup_complete = True
 
-    with pytest.deprecated_call(
-        match="The on_startup and on_shutdown parameters are deprecated"
-    ):
+    with pytest.deprecated_call(match="The on_startup and on_shutdown parameters are deprecated"):
         app = Starlette(
             on_startup=[run_startup],
             on_shutdown=[run_cleanup],
@@ -445,51 +441,34 @@ def test_decorator_deprecations() -> None:
     app = Starlette()
 
     with pytest.deprecated_call(
-        match=(
-            "The `exception_handler` decorator is deprecated, "
-            "and will be removed in version 1.0.0."
-        )
+        match=("The `exception_handler` decorator is deprecated, and will be removed in version 1.0.0.")
     ) as record:
         app.exception_handler(500)(http_exception)
         assert len(record) == 1
 
     with pytest.deprecated_call(
-        match=(
-            "The `middleware` decorator is deprecated, "
-            "and will be removed in version 1.0.0."
-        )
+        match=("The `middleware` decorator is deprecated, and will be removed in version 1.0.0.")
     ) as record:
 
-        async def middleware(
-            request: Request, call_next: RequestResponseEndpoint
-        ) -> None: ...  # pragma: no cover
+        async def middleware(request: Request, call_next: RequestResponseEndpoint) -> None: ...  # pragma: no cover
 
         app.middleware("http")(middleware)
         assert len(record) == 1
 
     with pytest.deprecated_call(
-        match=(
-            "The `route` decorator is deprecated, "
-            "and will be removed in version 1.0.0."
-        )
+        match=("The `route` decorator is deprecated, and will be removed in version 1.0.0.")
     ) as record:
         app.route("/")(async_homepage)
         assert len(record) == 1
 
     with pytest.deprecated_call(
-        match=(
-            "The `websocket_route` decorator is deprecated, "
-            "and will be removed in version 1.0.0."
-        )
+        match=("The `websocket_route` decorator is deprecated, and will be removed in version 1.0.0.")
     ) as record:
         app.websocket_route("/ws")(websocket_endpoint)
         assert len(record) == 1
 
     with pytest.deprecated_call(
-        match=(
-            "The `on_event` decorator is deprecated, "
-            "and will be removed in version 1.0.0."
-        )
+        match=("The `on_event` decorator is deprecated, and will be removed in version 1.0.0.")
     ) as record:
 
         async def startup() -> None: ...  # pragma: no cover
index 35c1110d14118bbb3438c8601fae7b81f339d031..a1bde67b9beb59f8b6a10705386d77f5aaf213fe 100644 (file)
@@ -259,9 +259,7 @@ def test_authentication_required(test_client_factory: TestClientFactory) -> None
         response = client.get("/dashboard/decorated")
         assert response.status_code == 403
 
-        response = client.get(
-            "/dashboard/decorated/sync", auth=("tomchristie", "example")
-        )
+        response = client.get("/dashboard/decorated/sync", auth=("tomchristie", "example"))
         assert response.status_code == 200
         assert response.json() == {
             "authenticated": True,
@@ -286,14 +284,10 @@ def test_websocket_authentication_required(
                 pass  # pragma: nocover
 
         with pytest.raises(WebSocketDisconnect):
-            with client.websocket_connect(
-                "/ws", headers={"Authorization": "basic foobar"}
-            ):
+            with client.websocket_connect("/ws", headers={"Authorization": "basic foobar"}):
                 pass  # pragma: nocover
 
-        with client.websocket_connect(
-            "/ws", auth=("tomchristie", "example")
-        ) as websocket:
+        with client.websocket_connect("/ws", auth=("tomchristie", "example")) as websocket:
             data = websocket.receive_json()
             assert data == {"authenticated": True, "user": "tomchristie"}
 
@@ -302,14 +296,10 @@ def test_websocket_authentication_required(
                 pass  # pragma: nocover
 
         with pytest.raises(WebSocketDisconnect):
-            with client.websocket_connect(
-                "/ws/decorated", headers={"Authorization": "basic foobar"}
-            ):
+            with client.websocket_connect("/ws/decorated", headers={"Authorization": "basic foobar"}):
                 pass  # pragma: nocover
 
-        with client.websocket_connect(
-            "/ws/decorated", auth=("tomchristie", "example")
-        ) as websocket:
+        with client.websocket_connect("/ws/decorated", auth=("tomchristie", "example")) as websocket:
             data = websocket.receive_json()
             assert data == {
                 "authenticated": True,
@@ -322,9 +312,7 @@ def test_authentication_redirect(test_client_factory: TestClientFactory) -> None
     with test_client_factory(app) as client:
         response = client.get("/admin")
         assert response.status_code == 200
-        url = "{}?{}".format(
-            "http://testserver/", urlencode({"next": "http://testserver/admin"})
-        )
+        url = "{}?{}".format("http://testserver/", urlencode({"next": "http://testserver/admin"}))
         assert response.url == url
 
         response = client.get("/admin", auth=("tomchristie", "example"))
@@ -333,9 +321,7 @@ def test_authentication_redirect(test_client_factory: TestClientFactory) -> None
 
         response = client.get("/admin/sync")
         assert response.status_code == 200
-        url = "{}?{}".format(
-            "http://testserver/", urlencode({"next": "http://testserver/admin/sync"})
-        )
+        url = "{}?{}".format("http://testserver/", urlencode({"next": "http://testserver/admin/sync"}))
         assert response.url == url
 
         response = client.get("/admin/sync", auth=("tomchristie", "example"))
@@ -359,11 +345,7 @@ def control_panel(request: Request) -> JSONResponse:
 
 other_app = Starlette(
     routes=[Route("/control-panel", control_panel)],
-    middleware=[
-        Middleware(
-            AuthenticationMiddleware, backend=BasicAuth(), on_error=on_auth_error
-        )
-    ],
+    middleware=[Middleware(AuthenticationMiddleware, backend=BasicAuth(), on_error=on_auth_error)],
 )
 
 
@@ -373,8 +355,6 @@ def test_custom_on_error(test_client_factory: TestClientFactory) -> None:
         assert response.status_code == 200
         assert response.json() == {"authenticated": True, "user": "tomchristie"}
 
-        response = client.get(
-            "/control-panel", headers={"Authorization": "basic foobar"}
-        )
+        response = client.get("/control-panel", headers={"Authorization": "basic foobar"})
         assert response.status_code == 401
         assert response.json() == {"error": "Invalid basic auth credentials"}
index cbffcc06a34565e60a22dc71807cace7d9a48045..990e270ea678bc59fa0f9b27f89eb3aeb2d756b5 100644 (file)
@@ -56,9 +56,7 @@ def test_multiple_tasks(test_client_factory: TestClientFactory) -> None:
         tasks.add_task(increment, amount=1)
         tasks.add_task(increment, amount=2)
         tasks.add_task(increment, amount=3)
-        response = Response(
-            "tasks initiated", media_type="text/plain", background=tasks
-        )
+        response = Response("tasks initiated", media_type="text/plain", background=tasks)
         await response(scope, receive, send)
 
     client = test_client_factory(app)
@@ -82,9 +80,7 @@ def test_multi_tasks_failure_avoids_next_execution(
         tasks = BackgroundTasks()
         tasks.add_task(increment)
         tasks.add_task(increment)
-        response = Response(
-            "tasks initiated", media_type="text/plain", background=tasks
-        )
+        response = Response("tasks initiated", media_type="text/plain", background=tasks)
         await response(scope, receive, send)
 
     client = test_client_factory(app)
index f375910077e4e7eeda8c5cf86a5bc5062fe82c9d..7d2cd1f9de6d568d6b9bad66235dbb5ef98eabd6 100644 (file)
@@ -14,9 +14,7 @@ def test_config_types() -> None:
     """
     We use `assert_type` to test the types returned by Config via mypy.
     """
-    config = Config(
-        environ={"STR": "some_str_value", "STR_CAST": "some_str_value", "BOOL": "true"}
-    )
+    config = Config(environ={"STR": "some_str_value", "STR_CAST": "some_str_value", "BOOL": "true"})
 
     assert_type(config("STR"), str)
     assert_type(config("STR_DEFAULT", default=""), str)
@@ -138,9 +136,7 @@ def test_environ() -> None:
 
 
 def test_config_with_env_prefix(tmpdir: Path, monkeypatch: pytest.MonkeyPatch) -> None:
-    config = Config(
-        environ={"APP_DEBUG": "value", "ENVIRONMENT": "dev"}, env_prefix="APP_"
-    )
+    config = Config(environ={"APP_DEBUG": "value", "ENVIRONMENT": "dev"}, env_prefix="APP_")
     assert config.get("DEBUG") == "value"
 
     with pytest.raises(KeyError):
index 520c987677fffa921a40a8bfdc936e1c7efbd8d2..ced1b86cc773e5f11f5b19a6697fea88c215c935 100644 (file)
@@ -48,23 +48,18 @@ def app() -> Router:
     )
 
 
-def test_datetime_convertor(
-    test_client_factory: TestClientFactory, app: Router
-) -> None:
+def test_datetime_convertor(test_client_factory: TestClientFactory, app: Router) -> None:
     client = test_client_factory(app)
     response = client.get("/datetime/2020-01-01T00:00:00")
     assert response.json() == {"datetime": "2020-01-01T00:00:00"}
 
     assert (
-        app.url_path_for("datetime-convertor", param=datetime(1996, 1, 22, 23, 0, 0))
-        == "/datetime/1996-01-22T23:00:00"
+        app.url_path_for("datetime-convertor", param=datetime(1996, 1, 22, 23, 0, 0)) == "/datetime/1996-01-22T23:00:00"
     )
 
 
 @pytest.mark.parametrize("param, status_code", [("1.0", 200), ("1-0", 404)])
-def test_default_float_convertor(
-    test_client_factory: TestClientFactory, param: str, status_code: int
-) -> None:
+def test_default_float_convertor(test_client_factory: TestClientFactory, param: str, status_code: int) -> None:
     def float_convertor(request: Request) -> JSONResponse:
         param = request.path_params["param"]
         assert isinstance(param, float)
index a6bca6ef6d6b86de9c5b95322c5d2b4b6b588c01..0e7d35c3c3d210e9bccf2cfe5589f86e4d0e856d 100644 (file)
@@ -115,9 +115,7 @@ def test_csv() -> None:
 
 
 def test_url_from_scope() -> None:
-    u = URL(
-        scope={"path": "/path/to/somewhere", "query_string": b"abc=123", "headers": []}
-    )
+    u = URL(scope={"path": "/path/to/somewhere", "query_string": b"abc=123", "headers": []})
     assert u == "/path/to/somewhere?abc=123"
     assert repr(u) == "URL('/path/to/somewhere?abc=123')"
 
@@ -296,13 +294,9 @@ def test_queryparams() -> None:
     assert dict(q) == {"a": "456", "b": "789"}
     assert str(q) == "a=123&a=456&b=789"
     assert repr(q) == "QueryParams('a=123&a=456&b=789')"
-    assert QueryParams({"a": "123", "b": "456"}) == QueryParams(
-        [("a", "123"), ("b", "456")]
-    )
+    assert QueryParams({"a": "123", "b": "456"}) == QueryParams([("a", "123"), ("b", "456")])
     assert QueryParams({"a": "123", "b": "456"}) == QueryParams("a=123&b=456")
-    assert QueryParams({"a": "123", "b": "456"}) == QueryParams(
-        {"b": "456", "a": "123"}
-    )
+    assert QueryParams({"a": "123", "b": "456"}) == QueryParams({"b": "456", "a": "123"})
     assert QueryParams() == QueryParams({})
     assert QueryParams([("a", "123"), ("a", "456")]) == QueryParams("a=123&a=456")
     assert QueryParams({"a": "123", "b": "456"}) != "invalid"
@@ -382,10 +376,7 @@ def test_formdata() -> None:
     assert len(form) == 2
     assert list(form) == ["a", "b"]
     assert dict(form) == {"a": "456", "b": upload}
-    assert (
-        repr(form)
-        == "FormData([('a', '123'), ('a', '456'), ('b', " + repr(upload) + ")])"
-    )
+    assert repr(form) == "FormData([('a', '123'), ('a', '456'), ('b', " + repr(upload) + ")])"
     assert FormData(form) == form
     assert FormData({"a": "123", "b": "789"}) == FormData([("a", "123"), ("b", "789")])
     assert FormData({"a": "123", "b": "789"}) != {"a": "123", "b": "789"}
@@ -402,10 +393,7 @@ async def test_upload_file_repr() -> None:
 async def test_upload_file_repr_headers() -> None:
     stream = io.BytesIO(b"data")
     file = UploadFile(filename="file", file=stream, headers=Headers({"foo": "bar"}))
-    assert (
-        repr(file)
-        == "UploadFile(filename='file', size=None, headers=Headers({'foo': 'bar'}))"
-    )
+    assert repr(file) == "UploadFile(filename='file', size=None, headers=Headers({'foo': 'bar'}))"
 
 
 def test_multidict() -> None:
@@ -425,9 +413,7 @@ def test_multidict() -> None:
     assert dict(q) == {"a": "456", "b": "789"}
     assert str(q) == "MultiDict([('a', '123'), ('a', '456'), ('b', '789')])"
     assert repr(q) == "MultiDict([('a', '123'), ('a', '456'), ('b', '789')])"
-    assert MultiDict({"a": "123", "b": "456"}) == MultiDict(
-        [("a", "123"), ("b", "456")]
-    )
+    assert MultiDict({"a": "123", "b": "456"}) == MultiDict([("a", "123"), ("b", "456")])
     assert MultiDict({"a": "123", "b": "456"}) == MultiDict({"b": "456", "a": "123"})
     assert MultiDict() == MultiDict({})
     assert MultiDict({"a": "123", "b": "456"}) != "invalid"
index 8f201e25be0407a3b19e14a8d5f2c26ede802077..42776a5b32243fd6ebccde8771ff43ac62eb01d5 100644 (file)
@@ -19,9 +19,7 @@ class Homepage(HTTPEndpoint):
         return PlainTextResponse(f"Hello, {username}!")
 
 
-app = Router(
-    routes=[Route("/", endpoint=Homepage), Route("/{username}", endpoint=Homepage)]
-)
+app = Router(routes=[Route("/", endpoint=Homepage), Route("/{username}", endpoint=Homepage)])
 
 
 @pytest.fixture
index f4e91ad871bae26bc59ea19a55250f7fe292f3c9..b3dc7843fbb5fc5ba7f8ccf21b9d1af879edf865 100644 (file)
@@ -42,9 +42,7 @@ async def read_body_and_raise_exc(request: Request) -> None:
     raise BadBodyException(422)
 
 
-async def handler_that_reads_body(
-    request: Request, exc: BadBodyException
-) -> JSONResponse:
+async def handler_that_reads_body(request: Request, exc: BadBodyException) -> JSONResponse:
     body = await request.body()
     return JSONResponse(status_code=422, content={"body": body.decode()})
 
@@ -158,9 +156,7 @@ def test_http_str() -> None:
 
 
 def test_http_repr() -> None:
-    assert repr(HTTPException(404)) == (
-        "HTTPException(status_code=404, detail='Not Found')"
-    )
+    assert repr(HTTPException(404)) == ("HTTPException(status_code=404, detail='Not Found')")
     assert repr(HTTPException(404, detail="Not Found: foo")) == (
         "HTTPException(status_code=404, detail='Not Found: foo')"
     )
index 8d97a0ba7b8faadb38d002f4e8c6cc381fbc0ba3..61c1bede19f8af000453e43705ea11a05df2e7bf 100644 (file)
@@ -127,17 +127,13 @@ def make_app_max_parts(max_files: int = 1000, max_fields: int = 1000) -> ASGIApp
     return app
 
 
-def test_multipart_request_data(
-    tmpdir: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_multipart_request_data(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
     client = test_client_factory(app)
     response = client.post("/", data={"some": "data"}, files=FORCE_MULTIPART)
     assert response.json() == {"some": "data"}
 
 
-def test_multipart_request_files(
-    tmpdir: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_multipart_request_files(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
     path = os.path.join(tmpdir, "test.txt")
     with open(path, "wb") as file:
         file.write(b"<file content>")
@@ -155,9 +151,7 @@ def test_multipart_request_files(
         }
 
 
-def test_multipart_request_files_with_content_type(
-    tmpdir: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_multipart_request_files_with_content_type(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
     path = os.path.join(tmpdir, "test.txt")
     with open(path, "wb") as file:
         file.write(b"<file content>")
@@ -175,9 +169,7 @@ def test_multipart_request_files_with_content_type(
         }
 
 
-def test_multipart_request_multiple_files(
-    tmpdir: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_multipart_request_multiple_files(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
     path1 = os.path.join(tmpdir, "test1.txt")
     with open(path1, "wb") as file:
         file.write(b"<file1 content>")
@@ -188,9 +180,7 @@ def test_multipart_request_multiple_files(
 
     client = test_client_factory(app)
     with open(path1, "rb") as f1, open(path2, "rb") as f2:
-        response = client.post(
-            "/", files={"test1": f1, "test2": ("test2.txt", f2, "text/plain")}
-        )
+        response = client.post("/", files={"test1": f1, "test2": ("test2.txt", f2, "text/plain")})
         assert response.json() == {
             "test1": {
                 "filename": "test1.txt",
@@ -207,9 +197,7 @@ def test_multipart_request_multiple_files(
         }
 
 
-def test_multipart_request_multiple_files_with_headers(
-    tmpdir: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_multipart_request_multiple_files_with_headers(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
     path1 = os.path.join(tmpdir, "test1.txt")
     with open(path1, "wb") as file:
         file.write(b"<file1 content>")
@@ -281,9 +269,7 @@ def test_multi_items(tmpdir: Path, test_client_factory: TestClientFactory) -> No
         }
 
 
-def test_multipart_request_mixed_files_and_data(
-    tmpdir: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_multipart_request_mixed_files_and_data(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
     client = test_client_factory(app)
     response = client.post(
         "/",
@@ -303,11 +289,7 @@ def test_multipart_request_mixed_files_and_data(
             b"value1\r\n"
             b"--a7f7ac8d4e2e437c877bb7b8d7cc549c--\r\n"
         ),
-        headers={
-            "Content-Type": (
-                "multipart/form-data; boundary=a7f7ac8d4e2e437c877bb7b8d7cc549c"
-            )
-        },
+        headers={"Content-Type": ("multipart/form-data; boundary=a7f7ac8d4e2e437c877bb7b8d7cc549c")},
     )
     assert response.json() == {
         "file": {
@@ -321,26 +303,19 @@ def test_multipart_request_mixed_files_and_data(
     }
 
 
-def test_multipart_request_with_charset_for_filename(
-    tmpdir: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_multipart_request_with_charset_for_filename(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
     client = test_client_factory(app)
     response = client.post(
         "/",
         data=(
             # file
             b"--a7f7ac8d4e2e437c877bb7b8d7cc549c\r\n"  # type: ignore
-            b'Content-Disposition: form-data; name="file"; filename="\xe6\x96\x87\xe6\x9b\xb8.txt"\r\n'  # noqa: E501
+            b'Content-Disposition: form-data; name="file"; filename="\xe6\x96\x87\xe6\x9b\xb8.txt"\r\n'
             b"Content-Type: text/plain\r\n\r\n"
             b"<file content>\r\n"
             b"--a7f7ac8d4e2e437c877bb7b8d7cc549c--\r\n"
         ),
-        headers={
-            "Content-Type": (
-                "multipart/form-data; charset=utf-8; "
-                "boundary=a7f7ac8d4e2e437c877bb7b8d7cc549c"
-            )
-        },
+        headers={"Content-Type": ("multipart/form-data; charset=utf-8; boundary=a7f7ac8d4e2e437c877bb7b8d7cc549c")},
     )
     assert response.json() == {
         "file": {
@@ -352,25 +327,19 @@ def test_multipart_request_with_charset_for_filename(
     }
 
 
-def test_multipart_request_without_charset_for_filename(
-    tmpdir: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_multipart_request_without_charset_for_filename(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
     client = test_client_factory(app)
     response = client.post(
         "/",
         data=(
             # file
             b"--a7f7ac8d4e2e437c877bb7b8d7cc549c\r\n"  # type: ignore
-            b'Content-Disposition: form-data; name="file"; filename="\xe7\x94\xbb\xe5\x83\x8f.jpg"\r\n'  # noqa: E501
+            b'Content-Disposition: form-data; name="file"; filename="\xe7\x94\xbb\xe5\x83\x8f.jpg"\r\n'
             b"Content-Type: image/jpeg\r\n\r\n"
             b"<file content>\r\n"
             b"--a7f7ac8d4e2e437c877bb7b8d7cc549c--\r\n"
         ),
-        headers={
-            "Content-Type": (
-                "multipart/form-data; boundary=a7f7ac8d4e2e437c877bb7b8d7cc549c"
-            )
-        },
+        headers={"Content-Type": ("multipart/form-data; boundary=a7f7ac8d4e2e437c877bb7b8d7cc549c")},
     )
     assert response.json() == {
         "file": {
@@ -382,9 +351,7 @@ def test_multipart_request_without_charset_for_filename(
     }
 
 
-def test_multipart_request_with_encoded_value(
-    tmpdir: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_multipart_request_with_encoded_value(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
     client = test_client_factory(app)
     response = client.post(
         "/",
@@ -395,19 +362,12 @@ def test_multipart_request_with_encoded_value(
             b"Transf\xc3\xa9rer\r\n"
             b"--20b303e711c4ab8c443184ac833ab00f--\r\n"
         ),
-        headers={
-            "Content-Type": (
-                "multipart/form-data; charset=utf-8; "
-                "boundary=20b303e711c4ab8c443184ac833ab00f"
-            )
-        },
+        headers={"Content-Type": ("multipart/form-data; charset=utf-8; boundary=20b303e711c4ab8c443184ac833ab00f")},
     )
     assert response.json() == {"value": "Transférer"}
 
 
-def test_urlencoded_request_data(
-    tmpdir: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_urlencoded_request_data(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
     client = test_client_factory(app)
     response = client.post("/", data={"some": "data"})
     assert response.json() == {"some": "data"}
@@ -419,37 +379,27 @@ def test_no_request_data(tmpdir: Path, test_client_factory: TestClientFactory) -
     assert response.json() == {}
 
 
-def test_urlencoded_percent_encoding(
-    tmpdir: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_urlencoded_percent_encoding(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
     client = test_client_factory(app)
     response = client.post("/", data={"some": "da ta"})
     assert response.json() == {"some": "da ta"}
 
 
-def test_urlencoded_percent_encoding_keys(
-    tmpdir: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_urlencoded_percent_encoding_keys(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
     client = test_client_factory(app)
     response = client.post("/", data={"so me": "data"})
     assert response.json() == {"so me": "data"}
 
 
-def test_urlencoded_multi_field_app_reads_body(
-    tmpdir: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_urlencoded_multi_field_app_reads_body(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
     client = test_client_factory(app_read_body)
     response = client.post("/", data={"some": "data", "second": "key pair"})
     assert response.json() == {"some": "data", "second": "key pair"}
 
 
-def test_multipart_multi_field_app_reads_body(
-    tmpdir: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_multipart_multi_field_app_reads_body(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
     client = test_client_factory(app_read_body)
-    response = client.post(
-        "/", data={"some": "data", "second": "key pair"}, files=FORCE_MULTIPART
-    )
+    response = client.post("/", data={"some": "data", "second": "key pair"}, files=FORCE_MULTIPART)
     assert response.json() == {"some": "data", "second": "key pair"}
 
 
@@ -481,7 +431,7 @@ def test_missing_boundary_parameter(
             "/",
             data=(
                 # file
-                b'Content-Disposition: form-data; name="file"; filename="\xe6\x96\x87\xe6\x9b\xb8.txt"\r\n'  # type: ignore # noqa: E501
+                b'Content-Disposition: form-data; name="file"; filename="\xe6\x96\x87\xe6\x9b\xb8.txt"\r\n'  # type: ignore
                 b"Content-Type: text/plain\r\n\r\n"
                 b"<file content>\r\n"
             ),
@@ -513,16 +463,10 @@ def test_missing_name_parameter_on_content_disposition(
                 b'Content-Disposition: form-data; ="field0"\r\n\r\n'
                 b"value0\r\n"
             ),
-            headers={
-                "Content-Type": (
-                    "multipart/form-data; boundary=a7f7ac8d4e2e437c877bb7b8d7cc549c"
-                )
-            },
+            headers={"Content-Type": ("multipart/form-data; boundary=a7f7ac8d4e2e437c877bb7b8d7cc549c")},
         )
         assert res.status_code == 400
-        assert (
-            res.text == 'The Content-Disposition header field "name" must be provided.'
-        )
+        assert res.text == 'The Content-Disposition header field "name" must be provided.'
 
 
 @pytest.mark.parametrize(
@@ -540,9 +484,7 @@ def test_too_many_fields_raise(
     client = test_client_factory(app)
     fields = []
     for i in range(1001):
-        fields.append(
-            "--B\r\n" f'Content-Disposition: form-data; name="N{i}";\r\n\r\n' "\r\n"
-        )
+        fields.append("--B\r\n" f'Content-Disposition: form-data; name="N{i}";\r\n\r\n' "\r\n")
     data = "".join(fields).encode("utf-8")
     with expectation:
         res = client.post(
@@ -569,11 +511,7 @@ def test_too_many_files_raise(
     client = test_client_factory(app)
     fields = []
     for i in range(1001):
-        fields.append(
-            "--B\r\n"
-            f'Content-Disposition: form-data; name="N{i}"; filename="F{i}";\r\n\r\n'
-            "\r\n"
-        )
+        fields.append("--B\r\n" f'Content-Disposition: form-data; name="N{i}"; filename="F{i}";\r\n\r\n' "\r\n")
     data = "".join(fields).encode("utf-8")
     with expectation:
         res = client.post(
@@ -602,11 +540,7 @@ def test_too_many_files_single_field_raise(
     for i in range(1001):
         # This uses the same field name "N" for all files, equivalent to a
         # multifile upload form field
-        fields.append(
-            "--B\r\n"
-            f'Content-Disposition: form-data; name="N"; filename="F{i}";\r\n\r\n'
-            "\r\n"
-        )
+        fields.append("--B\r\n" f'Content-Disposition: form-data; name="N"; filename="F{i}";\r\n\r\n' "\r\n")
     data = "".join(fields).encode("utf-8")
     with expectation:
         res = client.post(
@@ -633,14 +567,8 @@ def test_too_many_files_and_fields_raise(
     client = test_client_factory(app)
     fields = []
     for i in range(1001):
-        fields.append(
-            "--B\r\n"
-            f'Content-Disposition: form-data; name="F{i}"; filename="F{i}";\r\n\r\n'
-            "\r\n"
-        )
-        fields.append(
-            "--B\r\n" f'Content-Disposition: form-data; name="N{i}";\r\n\r\n' "\r\n"
-        )
+        fields.append("--B\r\n" f'Content-Disposition: form-data; name="F{i}"; filename="F{i}";\r\n\r\n' "\r\n")
+        fields.append("--B\r\n" f'Content-Disposition: form-data; name="N{i}";\r\n\r\n' "\r\n")
     data = "".join(fields).encode("utf-8")
     with expectation:
         res = client.post(
@@ -670,9 +598,7 @@ def test_max_fields_is_customizable_low_raises(
     client = test_client_factory(app)
     fields = []
     for i in range(2):
-        fields.append(
-            "--B\r\n" f'Content-Disposition: form-data; name="N{i}";\r\n\r\n' "\r\n"
-        )
+        fields.append("--B\r\n" f'Content-Disposition: form-data; name="N{i}";\r\n\r\n' "\r\n")
     data = "".join(fields).encode("utf-8")
     with expectation:
         res = client.post(
@@ -702,11 +628,7 @@ def test_max_files_is_customizable_low_raises(
     client = test_client_factory(app)
     fields = []
     for i in range(2):
-        fields.append(
-            "--B\r\n"
-            f'Content-Disposition: form-data; name="F{i}"; filename="F{i}";\r\n\r\n'
-            "\r\n"
-        )
+        fields.append("--B\r\n" f'Content-Disposition: form-data; name="F{i}"; filename="F{i}";\r\n\r\n' "\r\n")
     data = "".join(fields).encode("utf-8")
     with expectation:
         res = client.post(
@@ -724,14 +646,8 @@ def test_max_fields_is_customizable_high(
     client = test_client_factory(make_app_max_parts(max_fields=2000, max_files=2000))
     fields = []
     for i in range(2000):
-        fields.append(
-            "--B\r\n" f'Content-Disposition: form-data; name="N{i}";\r\n\r\n' "\r\n"
-        )
-        fields.append(
-            "--B\r\n"
-            f'Content-Disposition: form-data; name="F{i}"; filename="F{i}";\r\n\r\n'
-            "\r\n"
-        )
+        fields.append("--B\r\n" f'Content-Disposition: form-data; name="N{i}";\r\n\r\n' "\r\n")
+        fields.append("--B\r\n" f'Content-Disposition: form-data; name="F{i}"; filename="F{i}";\r\n\r\n' "\r\n")
     data = "".join(fields).encode("utf-8")
     data += b"--B--\r\n"
     res = client.post(
index c63c92de583f7aa1dca55d123359688799177394..ad1901ca5c3cdae94c35fd0b7c5d9af9ad92c532 100644 (file)
@@ -118,9 +118,7 @@ def test_streaming_response(test_client_factory: TestClientFactory) -> None:
 
         cleanup_task = BackgroundTask(numbers_for_cleanup, start=6, stop=9)
         generator = numbers(1, 5)
-        response = StreamingResponse(
-            generator, media_type="text/plain", background=cleanup_task
-        )
+        response = StreamingResponse(generator, media_type="text/plain", background=cleanup_task)
         await response(scope, receive, send)
 
     assert filled_by_bg_task == ""
@@ -236,9 +234,7 @@ def test_file_response(tmp_path: Path, test_client_factory: TestClientFactory) -
     cleanup_task = BackgroundTask(numbers_for_cleanup, start=6, stop=9)
 
     async def app(scope: Scope, receive: Receive, send: Send) -> None:
-        response = FileResponse(
-            path=path, filename="example.png", background=cleanup_task
-        )
+        response = FileResponse(path=path, filename="example.png", background=cleanup_task)
         await response(scope, receive, send)
 
     assert filled_by_bg_task == ""
@@ -284,9 +280,7 @@ async def test_file_response_on_head_method(tmp_path: Path) -> None:
     await app({"type": "http", "method": "head"}, receive, send)
 
 
-def test_file_response_set_media_type(
-    tmp_path: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_file_response_set_media_type(tmp_path: Path, test_client_factory: TestClientFactory) -> None:
     path = tmp_path / "xyz"
     path.write_bytes(b"<file content>")
 
@@ -298,9 +292,7 @@ def test_file_response_set_media_type(
     assert response.headers["content-type"] == "image/jpeg"
 
 
-def test_file_response_with_directory_raises_error(
-    tmp_path: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_file_response_with_directory_raises_error(tmp_path: Path, test_client_factory: TestClientFactory) -> None:
     app = FileResponse(path=tmp_path, filename="example.png")
     client = test_client_factory(app)
     with pytest.raises(RuntimeError) as exc_info:
@@ -308,9 +300,7 @@ def test_file_response_with_directory_raises_error(
     assert "is not a file" in str(exc_info.value)
 
 
-def test_file_response_with_missing_file_raises_error(
-    tmp_path: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_file_response_with_missing_file_raises_error(tmp_path: Path, test_client_factory: TestClientFactory) -> None:
     path = tmp_path / "404.txt"
     app = FileResponse(path=path, filename="404.txt")
     client = test_client_factory(app)
@@ -319,9 +309,7 @@ def test_file_response_with_missing_file_raises_error(
     assert "does not exist" in str(exc_info.value)
 
 
-def test_file_response_with_chinese_filename(
-    tmp_path: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_file_response_with_chinese_filename(tmp_path: Path, test_client_factory: TestClientFactory) -> None:
     content = b"file content"
     filename = "你好.txt"  # probably "Hello.txt" in Chinese
     path = tmp_path / filename
@@ -335,9 +323,7 @@ def test_file_response_with_chinese_filename(
     assert response.headers["content-disposition"] == expected_disposition
 
 
-def test_file_response_with_inline_disposition(
-    tmp_path: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_file_response_with_inline_disposition(tmp_path: Path, test_client_factory: TestClientFactory) -> None:
     content = b"file content"
     filename = "hello.txt"
     path = tmp_path / filename
@@ -356,9 +342,7 @@ def test_file_response_with_method_warns(tmp_path: Path) -> None:
         FileResponse(path=tmp_path, filename="example.png", method="GET")
 
 
-def test_set_cookie(
-    test_client_factory: TestClientFactory, monkeypatch: pytest.MonkeyPatch
-) -> None:
+def test_set_cookie(test_client_factory: TestClientFactory, monkeypatch: pytest.MonkeyPatch) -> None:
     # Mock time used as a reference for `Expires` by stdlib `SimpleCookie`.
     mocked_now = dt.datetime(2037, 1, 22, 12, 0, 0, tzinfo=dt.timezone.utc)
     monkeypatch.setattr(time, "time", lambda: mocked_now.timestamp())
@@ -382,8 +366,7 @@ def test_set_cookie(
     response = client.get("/")
     assert response.text == "Hello, world!"
     assert (
-        response.headers["set-cookie"]
-        == "mycookie=myvalue; Domain=localhost; expires=Thu, 22 Jan 2037 12:00:10 GMT; "
+        response.headers["set-cookie"] == "mycookie=myvalue; Domain=localhost; expires=Thu, 22 Jan 2037 12:00:10 GMT; "
         "HttpOnly; Max-Age=10; Path=/; SameSite=none; Secure"
     )
 
@@ -403,9 +386,7 @@ def test_set_cookie_path_none(test_client_factory: TestClientFactory) -> None:
 @pytest.mark.parametrize(
     "expires",
     [
-        pytest.param(
-            dt.datetime(2037, 1, 22, 12, 0, 10, tzinfo=dt.timezone.utc), id="datetime"
-        ),
+        pytest.param(dt.datetime(2037, 1, 22, 12, 0, 10, tzinfo=dt.timezone.utc), id="datetime"),
         pytest.param("Thu, 22 Jan 2037 12:00:10 GMT", id="str"),
         pytest.param(10, id="int"),
     ],
@@ -495,9 +476,7 @@ def test_response_do_not_add_redundant_charset(
     assert response.headers["content-type"] == "text/plain; charset=utf-8"
 
 
-def test_file_response_known_size(
-    tmp_path: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_file_response_known_size(tmp_path: Path, test_client_factory: TestClientFactory) -> None:
     path = tmp_path / "xyz"
     content = b"<file content>" * 1000
     path.write_bytes(content)
@@ -518,9 +497,7 @@ def test_streaming_response_unknown_size(
 
 
 def test_streaming_response_known_size(test_client_factory: TestClientFactory) -> None:
-    app = StreamingResponse(
-        content=iter(["hello", "world"]), headers={"content-length": "10"}
-    )
+    app = StreamingResponse(content=iter(["hello", "world"]), headers={"content-length": "10"})
     client: TestClient = test_client_factory(app)
     response = client.get("/")
     assert response.headers["content-length"] == "10"
index 1490723b4376993156b78ceb160f64f538b28a20..9fa44def4c04c3a58c58e7c678fb13db85472e9b 100644 (file)
@@ -232,10 +232,7 @@ def test_route_converters(client: TestClient) -> None:
     response = client.get("/path-with-parentheses(7)")
     assert response.status_code == 200
     assert response.json() == {"int": 7}
-    assert (
-        app.url_path_for("path-with-parentheses", param=7)
-        == "/path-with-parentheses(7)"
-    )
+    assert app.url_path_for("path-with-parentheses", param=7) == "/path-with-parentheses(7)"
 
     # Test float conversion
     response = client.get("/float/25.5")
@@ -247,18 +244,14 @@ def test_route_converters(client: TestClient) -> None:
     response = client.get("/path/some/example")
     assert response.status_code == 200
     assert response.json() == {"path": "some/example"}
-    assert (
-        app.url_path_for("path-convertor", param="some/example") == "/path/some/example"
-    )
+    assert app.url_path_for("path-convertor", param="some/example") == "/path/some/example"
 
     # Test UUID conversion
     response = client.get("/uuid/ec38df32-ceda-4cfa-9b4a-1aeb94ad551a")
     assert response.status_code == 200
     assert response.json() == {"uuid": "ec38df32-ceda-4cfa-9b4a-1aeb94ad551a"}
     assert (
-        app.url_path_for(
-            "uuid-convertor", param=uuid.UUID("ec38df32-ceda-4cfa-9b4a-1aeb94ad551a")
-        )
+        app.url_path_for("uuid-convertor", param=uuid.UUID("ec38df32-ceda-4cfa-9b4a-1aeb94ad551a"))
         == "/uuid/ec38df32-ceda-4cfa-9b4a-1aeb94ad551a"
     )
 
@@ -267,13 +260,9 @@ def test_url_path_for() -> None:
     assert app.url_path_for("homepage") == "/"
     assert app.url_path_for("user", username="tomchristie") == "/users/tomchristie"
     assert app.url_path_for("websocket_endpoint") == "/ws"
-    with pytest.raises(
-        NoMatchFound, match='No route exists for name "broken" and params "".'
-    ):
+    with pytest.raises(NoMatchFound, match='No route exists for name "broken" and params "".'):
         assert app.url_path_for("broken")
-    with pytest.raises(
-        NoMatchFound, match='No route exists for name "broken" and params "key, key2".'
-    ):
+    with pytest.raises(NoMatchFound, match='No route exists for name "broken" and params "key, key2".'):
         assert app.url_path_for("broken", key="value", key2="value2")
     with pytest.raises(AssertionError):
         app.url_path_for("user", username="tom/christie")
@@ -282,32 +271,21 @@ def test_url_path_for() -> None:
 
 
 def test_url_for() -> None:
+    assert app.url_path_for("homepage").make_absolute_url(base_url="https://example.org") == "https://example.org/"
     assert (
-        app.url_path_for("homepage").make_absolute_url(base_url="https://example.org")
-        == "https://example.org/"
-    )
-    assert (
-        app.url_path_for("homepage").make_absolute_url(
-            base_url="https://example.org/root_path/"
-        )
+        app.url_path_for("homepage").make_absolute_url(base_url="https://example.org/root_path/")
         == "https://example.org/root_path/"
     )
     assert (
-        app.url_path_for("user", username="tomchristie").make_absolute_url(
-            base_url="https://example.org"
-        )
+        app.url_path_for("user", username="tomchristie").make_absolute_url(base_url="https://example.org")
         == "https://example.org/users/tomchristie"
     )
     assert (
-        app.url_path_for("user", username="tomchristie").make_absolute_url(
-            base_url="https://example.org/root_path/"
-        )
+        app.url_path_for("user", username="tomchristie").make_absolute_url(base_url="https://example.org/root_path/")
         == "https://example.org/root_path/users/tomchristie"
     )
     assert (
-        app.url_path_for("websocket_endpoint").make_absolute_url(
-            base_url="https://example.org"
-        )
+        app.url_path_for("websocket_endpoint").make_absolute_url(base_url="https://example.org")
         == "wss://example.org/ws"
     )
 
@@ -409,13 +387,8 @@ def test_reverse_mount_urls() -> None:
 
     users = Router([Route("/{username}", ok, name="user")])
     mounted = Router([Mount("/{subpath}/users", users, name="users")])
-    assert (
-        mounted.url_path_for("users:user", subpath="test", username="tom")
-        == "/test/users/tom"
-    )
-    assert (
-        mounted.url_path_for("users", subpath="test", path="/tom") == "/test/users/tom"
-    )
+    assert mounted.url_path_for("users:user", subpath="test", username="tom") == "/test/users/tom"
+    assert mounted.url_path_for("users", subpath="test", path="/tom") == "/test/users/tom"
 
 
 def test_mount_at_root(test_client_factory: TestClientFactory) -> None:
@@ -472,9 +445,7 @@ def test_host_routing(test_client_factory: TestClientFactory) -> None:
     response = client.get("/")
     assert response.status_code == 200
 
-    client = test_client_factory(
-        mixed_hosts_app, base_url="https://port.example.org:3600/"
-    )
+    client = test_client_factory(mixed_hosts_app, base_url="https://port.example.org:3600/")
 
     response = client.get("/users")
     assert response.status_code == 404
@@ -489,31 +460,23 @@ def test_host_routing(test_client_factory: TestClientFactory) -> None:
     response = client.get("/")
     assert response.status_code == 200
 
-    client = test_client_factory(
-        mixed_hosts_app, base_url="https://port.example.org:5600/"
-    )
+    client = test_client_factory(mixed_hosts_app, base_url="https://port.example.org:5600/")
 
     response = client.get("/")
     assert response.status_code == 200
 
 
 def test_host_reverse_urls() -> None:
+    assert mixed_hosts_app.url_path_for("homepage").make_absolute_url("https://whatever") == "https://www.example.org/"
     assert (
-        mixed_hosts_app.url_path_for("homepage").make_absolute_url("https://whatever")
-        == "https://www.example.org/"
-    )
-    assert (
-        mixed_hosts_app.url_path_for("users").make_absolute_url("https://whatever")
-        == "https://www.example.org/users"
+        mixed_hosts_app.url_path_for("users").make_absolute_url("https://whatever") == "https://www.example.org/users"
     )
     assert (
         mixed_hosts_app.url_path_for("api:users").make_absolute_url("https://whatever")
         == "https://api.example.org/users"
     )
     assert (
-        mixed_hosts_app.url_path_for("port:homepage").make_absolute_url(
-            "https://whatever"
-        )
+        mixed_hosts_app.url_path_for("port:homepage").make_absolute_url("https://whatever")
         == "https://port.example.org:3600/"
     )
 
@@ -523,9 +486,7 @@ async def subdomain_app(scope: Scope, receive: Receive, send: Send) -> None:
     await response(scope, receive, send)
 
 
-subdomain_router = Router(
-    routes=[Host("{subdomain}.example.org", app=subdomain_app, name="subdomains")]
-)
+subdomain_router = Router(routes=[Host("{subdomain}.example.org", app=subdomain_app, name="subdomains")])
 
 
 def test_subdomain_routing(test_client_factory: TestClientFactory) -> None:
@@ -538,9 +499,9 @@ def test_subdomain_routing(test_client_factory: TestClientFactory) -> None:
 
 def test_subdomain_reverse_urls() -> None:
     assert (
-        subdomain_router.url_path_for(
-            "subdomains", subdomain="foo", path="/homepage"
-        ).make_absolute_url("https://whatever")
+        subdomain_router.url_path_for("subdomains", subdomain="foo", path="/homepage").make_absolute_url(
+            "https://whatever"
+        )
         == "https://foo.example.org/homepage"
     )
 
@@ -566,9 +527,7 @@ echo_url_routes = [
 
 def test_url_for_with_root_path(test_client_factory: TestClientFactory) -> None:
     app = Starlette(routes=echo_url_routes)
-    client = test_client_factory(
-        app, base_url="https://www.example.org/", root_path="/sub_path"
-    )
+    client = test_client_factory(app, base_url="https://www.example.org/", root_path="/sub_path")
     response = client.get("/sub_path/")
     assert response.json() == {
         "index": "https://www.example.org/sub_path/",
@@ -657,9 +616,7 @@ def test_lifespan_async(test_client_factory: TestClientFactory) -> None:
         nonlocal shutdown_complete
         shutdown_complete = True
 
-    with pytest.deprecated_call(
-        match="The on_startup and on_shutdown parameters are deprecated"
-    ):
+    with pytest.deprecated_call(match="The on_startup and on_shutdown parameters are deprecated"):
         app = Router(
             on_startup=[run_startup],
             on_shutdown=[run_shutdown],
@@ -697,18 +654,11 @@ def test_lifespan_with_on_events(test_client_factory: TestClientFactory) -> None
         nonlocal shutdown_called
         shutdown_called = True
 
-    with pytest.deprecated_call(
-        match="The on_startup and on_shutdown parameters are deprecated"
-    ):
+    with pytest.deprecated_call(match="The on_startup and on_shutdown parameters are deprecated"):
         with pytest.warns(
-            UserWarning,
-            match=(
-                "The `lifespan` parameter cannot be used with `on_startup` or `on_shutdown`."  # noqa: E501
-            ),
+            UserWarning, match="The `lifespan` parameter cannot be used with `on_startup` or `on_shutdown`."
         ):
-            app = Router(
-                on_startup=[run_startup], on_shutdown=[run_shutdown], lifespan=lifespan
-            )
+            app = Router(on_startup=[run_startup], on_shutdown=[run_shutdown], lifespan=lifespan)
 
             assert not lifespan_called
             assert not startup_called
@@ -738,9 +688,7 @@ def test_lifespan_sync(test_client_factory: TestClientFactory) -> None:
         nonlocal shutdown_complete
         shutdown_complete = True
 
-    with pytest.deprecated_call(
-        match="The on_startup and on_shutdown parameters are deprecated"
-    ):
+    with pytest.deprecated_call(match="The on_startup and on_shutdown parameters are deprecated"):
         app = Router(
             on_startup=[run_startup],
             on_shutdown=[run_shutdown],
@@ -775,9 +723,7 @@ def test_lifespan_state_unsupported(
         del scope["state"]
         await app(scope, receive, send)
 
-    with pytest.raises(
-        RuntimeError, match='The server does not support "state" in the lifespan scope'
-    ):
+    with pytest.raises(RuntimeError, match='The server does not support "state" in the lifespan scope'):
         with test_client_factory(no_state_wrapper):
             raise AssertionError("Should not be called")  # pragma: no cover
 
@@ -834,9 +780,7 @@ def test_raise_on_startup(test_client_factory: TestClientFactory) -> None:
     def run_startup() -> None:
         raise RuntimeError()
 
-    with pytest.deprecated_call(
-        match="The on_startup and on_shutdown parameters are deprecated"
-    ):
+    with pytest.deprecated_call(match="The on_startup and on_shutdown parameters are deprecated"):
         router = Router(on_startup=[run_startup])
     startup_failed = False
 
@@ -859,9 +803,7 @@ def test_raise_on_shutdown(test_client_factory: TestClientFactory) -> None:
     def run_shutdown() -> None:
         raise RuntimeError()
 
-    with pytest.deprecated_call(
-        match="The on_startup and on_shutdown parameters are deprecated"
-    ):
+    with pytest.deprecated_call(match="The on_startup and on_shutdown parameters are deprecated"):
         app = Router(on_shutdown=[run_shutdown])
 
     with pytest.raises(RuntimeError):
@@ -934,9 +876,7 @@ class Endpoint:
         pytest.param(lambda request: ..., "<lambda>", id="lambda"),
     ],
 )
-def test_route_name(
-    endpoint: typing.Callable[..., Response], expected_name: str
-) -> None:
+def test_route_name(endpoint: typing.Callable[..., Response], expected_name: str) -> None:
     assert Route(path="/", endpoint=endpoint).name == expected_name
 
 
@@ -1172,10 +1112,7 @@ def test_websocket_route_middleware(
 
 def test_route_repr() -> None:
     route = Route("/welcome", endpoint=homepage)
-    assert (
-        repr(route)
-        == "Route(path='/welcome', name='homepage', methods=['GET', 'HEAD'])"
-    )
+    assert repr(route) == "Route(path='/welcome', name='homepage', methods=['GET', 'HEAD'])"
 
 
 def test_route_repr_without_methods() -> None:
@@ -1264,9 +1201,7 @@ async def echo_paths(request: Request, name: str) -> JSONResponse:
     )
 
 
-async def pure_asgi_echo_paths(
-    scope: Scope, receive: Receive, send: Send, name: str
-) -> None:
+async def pure_asgi_echo_paths(scope: Scope, receive: Receive, send: Send, name: str) -> None:
     data = {"name": name, "path": scope["path"], "root_path": scope["root_path"]}
     content = json.dumps(data).encode("utf-8")
     await send(
@@ -1304,9 +1239,7 @@ echo_paths_routes = [
 
 def test_paths_with_root_path(test_client_factory: TestClientFactory) -> None:
     app = Starlette(routes=echo_paths_routes)
-    client = test_client_factory(
-        app, base_url="https://www.example.org/", root_path="/root"
-    )
+    client = test_client_factory(app, base_url="https://www.example.org/", root_path="/root")
     response = client.get("/root/path")
     assert response.status_code == 200
     assert response.json() == {
index f4a5b4ad9950b53dab5500cae94a4732da2a19e4..3b321ca0b195041bde1040c3c9f3390027f21975 100644 (file)
@@ -7,9 +7,7 @@ from starlette.schemas import SchemaGenerator
 from starlette.websockets import WebSocket
 from tests.types import TestClientFactory
 
-schemas = SchemaGenerator(
-    {"openapi": "3.0.0", "info": {"title": "Example API", "version": "1.0"}}
-)
+schemas = SchemaGenerator({"openapi": "3.0.0", "info": {"title": "Example API", "version": "1.0"}})
 
 
 def ws(session: WebSocket) -> None:
@@ -142,7 +140,7 @@ def test_schema_generation() -> None:
                 "get": {
                     "responses": {
                         200: {
-                            "description": "A list of " "organisations.",
+                            "description": "A list of organisations.",
                             "examples": [{"name": "Foo Corp."}, {"name": "Acme Ltd."}],
                         }
                     }
@@ -157,25 +155,13 @@ def test_schema_generation() -> None:
                 },
             },
             "/regular-docstring-and-schema": {
-                "get": {
-                    "responses": {
-                        200: {"description": "This is included in the schema."}
-                    }
-                }
+                "get": {"responses": {200: {"description": "This is included in the schema."}}}
             },
             "/subapp/subapp-endpoint": {
-                "get": {
-                    "responses": {
-                        200: {"description": "This endpoint is part of a subapp."}
-                    }
-                }
+                "get": {"responses": {200: {"description": "This endpoint is part of a subapp."}}}
             },
             "/subapp2/subapp-endpoint": {
-                "get": {
-                    "responses": {
-                        200: {"description": "This endpoint is part of a subapp."}
-                    }
-                }
+                "get": {"responses": {200: {"description": "This endpoint is part of a subapp."}}}
             },
             "/users": {
                 "get": {
@@ -186,11 +172,7 @@ def test_schema_generation() -> None:
                         }
                     }
                 },
-                "post": {
-                    "responses": {
-                        200: {"description": "A user.", "examples": {"username": "tom"}}
-                    }
-                },
+                "post": {"responses": {200: {"description": "A user.", "examples": {"username": "tom"}}}},
             },
             "/users/{id}": {
                 "get": {
index 65d71b97b8984bd99f346bf7238b144353c3483f..8beb3cd8716d210bbe3bd6cf39bb53ae4240cac6 100644 (file)
@@ -31,9 +31,7 @@ def test_staticfiles(tmpdir: Path, test_client_factory: TestClientFactory) -> No
     assert response.text == "<file content>"
 
 
-def test_staticfiles_with_pathlib(
-    tmp_path: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_staticfiles_with_pathlib(tmp_path: Path, test_client_factory: TestClientFactory) -> None:
     path = tmp_path / "example.txt"
     with open(path, "w") as file:
         file.write("<file content>")
@@ -45,9 +43,7 @@ def test_staticfiles_with_pathlib(
     assert response.text == "<file content>"
 
 
-def test_staticfiles_head_with_middleware(
-    tmpdir: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_staticfiles_head_with_middleware(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
     """
     see https://github.com/encode/starlette/pull/935
     """
@@ -55,9 +51,7 @@ def test_staticfiles_head_with_middleware(
     with open(path, "w") as file:
         file.write("x" * 100)
 
-    async def does_nothing_middleware(
-        request: Request, call_next: RequestResponseEndpoint
-    ) -> Response:
+    async def does_nothing_middleware(request: Request, call_next: RequestResponseEndpoint) -> Response:
         response = await call_next(request)
         return response
 
@@ -99,9 +93,7 @@ def test_staticfiles_post(tmpdir: Path, test_client_factory: TestClientFactory)
     assert response.text == "Method Not Allowed"
 
 
-def test_staticfiles_with_directory_returns_404(
-    tmpdir: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_staticfiles_with_directory_returns_404(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
     path = os.path.join(tmpdir, "example.txt")
     with open(path, "w") as file:
         file.write("<file content>")
@@ -115,9 +107,7 @@ def test_staticfiles_with_directory_returns_404(
     assert response.text == "Not Found"
 
 
-def test_staticfiles_with_missing_file_returns_404(
-    tmpdir: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_staticfiles_with_missing_file_returns_404(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
     path = os.path.join(tmpdir, "example.txt")
     with open(path, "w") as file:
         file.write("<file content>")
@@ -138,9 +128,7 @@ def test_staticfiles_instantiated_with_missing_directory(tmpdir: Path) -> None:
     assert "does not exist" in str(exc_info.value)
 
 
-def test_staticfiles_configured_with_missing_directory(
-    tmpdir: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_staticfiles_configured_with_missing_directory(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
     path = os.path.join(tmpdir, "no_such_directory")
     app = StaticFiles(directory=path, check_dir=False)
     client = test_client_factory(app)
@@ -163,9 +151,7 @@ def test_staticfiles_configured_with_file_instead_of_directory(
     assert "is not a directory" in str(exc_info.value)
 
 
-def test_staticfiles_config_check_occurs_only_once(
-    tmpdir: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_staticfiles_config_check_occurs_only_once(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
     app = StaticFiles(directory=tmpdir)
     client = test_client_factory(app)
     assert not app.config_checked
@@ -199,9 +185,7 @@ def test_staticfiles_prevents_breaking_out_of_directory(tmpdir: Path) -> None:
     assert exc_info.value.detail == "Not Found"
 
 
-def test_staticfiles_never_read_file_for_head_method(
-    tmpdir: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_staticfiles_never_read_file_for_head_method(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
     path = os.path.join(tmpdir, "example.txt")
     with open(path, "w") as file:
         file.write("<file content>")
@@ -214,9 +198,7 @@ def test_staticfiles_never_read_file_for_head_method(
     assert response.headers["content-length"] == "14"
 
 
-def test_staticfiles_304_with_etag_match(
-    tmpdir: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_staticfiles_304_with_etag_match(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
     path = os.path.join(tmpdir, "example.txt")
     with open(path, "w") as file:
         file.write("<file content>")
@@ -229,9 +211,7 @@ def test_staticfiles_304_with_etag_match(
     second_resp = client.get("/example.txt", headers={"if-none-match": last_etag})
     assert second_resp.status_code == 304
     assert second_resp.content == b""
-    second_resp = client.get(
-        "/example.txt", headers={"if-none-match": f'W/{last_etag}, "123"'}
-    )
+    second_resp = client.get("/example.txt", headers={"if-none-match": f'W/{last_etag}, "123"'})
     assert second_resp.status_code == 304
     assert second_resp.content == b""
 
@@ -240,9 +220,7 @@ def test_staticfiles_304_with_last_modified_compare_last_req(
     tmpdir: Path, test_client_factory: TestClientFactory
 ) -> None:
     path = os.path.join(tmpdir, "example.txt")
-    file_last_modified_time = time.mktime(
-        time.strptime("2013-10-10 23:40:00", "%Y-%m-%d %H:%M:%S")
-    )
+    file_last_modified_time = time.mktime(time.strptime("2013-10-10 23:40:00", "%Y-%m-%d %H:%M:%S"))
     with open(path, "w") as file:
         file.write("<file content>")
     os.utime(path, (file_last_modified_time, file_last_modified_time))
@@ -250,22 +228,16 @@ def test_staticfiles_304_with_last_modified_compare_last_req(
     app = StaticFiles(directory=tmpdir)
     client = test_client_factory(app)
     # last modified less than last request, 304
-    response = client.get(
-        "/example.txt", headers={"If-Modified-Since": "Thu, 11 Oct 2013 15:30:19 GMT"}
-    )
+    response = client.get("/example.txt", headers={"If-Modified-Since": "Thu, 11 Oct 2013 15:30:19 GMT"})
     assert response.status_code == 304
     assert response.content == b""
     # last modified greater than last request, 200 with content
-    response = client.get(
-        "/example.txt", headers={"If-Modified-Since": "Thu, 20 Feb 2012 15:30:19 GMT"}
-    )
+    response = client.get("/example.txt", headers={"If-Modified-Since": "Thu, 20 Feb 2012 15:30:19 GMT"})
     assert response.status_code == 200
     assert response.content == b"<file content>"
 
 
-def test_staticfiles_html_normal(
-    tmpdir: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_staticfiles_html_normal(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
     path = os.path.join(tmpdir, "404.html")
     with open(path, "w") as file:
         file.write("<h1>Custom not found page</h1>")
@@ -298,9 +270,7 @@ def test_staticfiles_html_normal(
     assert response.text == "<h1>Custom not found page</h1>"
 
 
-def test_staticfiles_html_without_index(
-    tmpdir: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_staticfiles_html_without_index(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
     path = os.path.join(tmpdir, "404.html")
     with open(path, "w") as file:
         file.write("<h1>Custom not found page</h1>")
@@ -325,9 +295,7 @@ def test_staticfiles_html_without_index(
     assert response.text == "<h1>Custom not found page</h1>"
 
 
-def test_staticfiles_html_without_404(
-    tmpdir: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_staticfiles_html_without_404(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
     path = os.path.join(tmpdir, "dir")
     os.mkdir(path)
     path = os.path.join(path, "index.html")
@@ -352,9 +320,7 @@ def test_staticfiles_html_without_404(
     assert exc_info.value.status_code == 404
 
 
-def test_staticfiles_html_only_files(
-    tmpdir: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_staticfiles_html_only_files(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
     path = os.path.join(tmpdir, "hello.html")
     with open(path, "w") as file:
         file.write("<h1>Hello</h1>")
@@ -381,9 +347,7 @@ def test_staticfiles_cache_invalidation_for_deleted_file_html_mode(
     with open(path_some, "w") as file:
         file.write("<p>some file</p>")
 
-    common_modified_time = time.mktime(
-        time.strptime("2013-10-10 23:40:00", "%Y-%m-%d %H:%M:%S")
-    )
+    common_modified_time = time.mktime(time.strptime("2013-10-10 23:40:00", "%Y-%m-%d %H:%M:%S"))
     os.utime(path_404, (common_modified_time, common_modified_time))
     os.utime(path_some, (common_modified_time, common_modified_time))
 
@@ -435,9 +399,7 @@ def test_staticfiles_with_invalid_dir_permissions_returns_401(
         tmp_path.chmod(original_mode)
 
 
-def test_staticfiles_with_missing_dir_returns_404(
-    tmpdir: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_staticfiles_with_missing_dir_returns_404(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
     path = os.path.join(tmpdir, "example.txt")
     with open(path, "w") as file:
         file.write("<file content>")
@@ -451,9 +413,7 @@ def test_staticfiles_with_missing_dir_returns_404(
     assert response.text == "Not Found"
 
 
-def test_staticfiles_access_file_as_dir_returns_404(
-    tmpdir: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_staticfiles_access_file_as_dir_returns_404(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
     path = os.path.join(tmpdir, "example.txt")
     with open(path, "w") as file:
         file.write("<file content>")
@@ -467,9 +427,7 @@ def test_staticfiles_access_file_as_dir_returns_404(
     assert response.text == "Not Found"
 
 
-def test_staticfiles_filename_too_long(
-    tmpdir: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_staticfiles_filename_too_long(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
     routes = [Mount("/", app=StaticFiles(directory=tmpdir), name="static")]
     app = Starlette(routes=routes)
     client = test_client_factory(app)
@@ -503,9 +461,7 @@ def test_staticfiles_unhandled_os_error_returns_500(
     assert response.text == "Internal Server Error"
 
 
-def test_staticfiles_follows_symlinks(
-    tmpdir: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_staticfiles_follows_symlinks(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
     statics_path = os.path.join(tmpdir, "statics")
     os.mkdir(statics_path)
 
@@ -526,9 +482,7 @@ def test_staticfiles_follows_symlinks(
     assert response.text == "<h1>Hello</h1>"
 
 
-def test_staticfiles_follows_symlink_directories(
-    tmpdir: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_staticfiles_follows_symlink_directories(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
     statics_path = os.path.join(tmpdir, "statics")
     statics_html_path = os.path.join(statics_path, "html")
     os.mkdir(statics_path)
index 04719e87ed88ce710f960938ed442b206122c9b6..4852c06ef976ae3dca2b6b129bd8f591db89f3ed 100644 (file)
@@ -8,13 +8,11 @@ import pytest
     (
         (
             "WS_1004_NO_STATUS_RCVD",
-            "'WS_1004_NO_STATUS_RCVD' is deprecated. "
-            "Use 'WS_1005_NO_STATUS_RCVD' instead.",
+            "'WS_1004_NO_STATUS_RCVD' is deprecated. Use 'WS_1005_NO_STATUS_RCVD' instead.",
         ),
         (
             "WS_1005_ABNORMAL_CLOSURE",
-            "'WS_1005_ABNORMAL_CLOSURE' is deprecated. "
-            "Use 'WS_1006_ABNORMAL_CLOSURE' instead.",
+            "'WS_1005_ABNORMAL_CLOSURE' is deprecated. Use 'WS_1006_ABNORMAL_CLOSURE' instead.",
         ),
     ),
 )
index 8e344f331f27b66df6d2854e24af21e1b5eb9643..6b2080c17793877d7b0cbe5e8c9f5ef83abb0086 100644 (file)
@@ -36,9 +36,7 @@ def test_templates(tmpdir: Path, test_client_factory: TestClientFactory) -> None
     assert set(response.context.keys()) == {"request"}  # type: ignore
 
 
-def test_calls_context_processors(
-    tmp_path: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_calls_context_processors(tmp_path: Path, test_client_factory: TestClientFactory) -> None:
     path = tmp_path / "index.html"
     path.write_text("<html>Hello {{ username }}</html>")
 
@@ -66,9 +64,7 @@ def test_calls_context_processors(
     assert set(response.context.keys()) == {"request", "username"}  # type: ignore
 
 
-def test_template_with_middleware(
-    tmpdir: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_template_with_middleware(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
     path = os.path.join(tmpdir, "index.html")
     with open(path, "w") as file:
         file.write("<html>Hello, <a href='{{ url_for('homepage') }}'>world</a></html>")
@@ -77,9 +73,7 @@ def test_template_with_middleware(
         return templates.TemplateResponse(request, "index.html")
 
     class CustomMiddleware(BaseHTTPMiddleware):
-        async def dispatch(
-            self, request: Request, call_next: RequestResponseEndpoint
-        ) -> Response:
+        async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
             return await call_next(request)
 
     app = Starlette(
@@ -96,9 +90,7 @@ def test_template_with_middleware(
     assert set(response.context.keys()) == {"request"}  # type: ignore
 
 
-def test_templates_with_directories(
-    tmp_path: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_templates_with_directories(tmp_path: Path, test_client_factory: TestClientFactory) -> None:
     dir_a = tmp_path.resolve() / "a"
     dir_a.mkdir()
     template_a = dir_a / "template_a.html"
@@ -134,16 +126,12 @@ def test_templates_with_directories(
 
 
 def test_templates_require_directory_or_environment() -> None:
-    with pytest.raises(
-        AssertionError, match="either 'directory' or 'env' arguments must be passed"
-    ):
+    with pytest.raises(AssertionError, match="either 'directory' or 'env' arguments must be passed"):
         Jinja2Templates()  # type: ignore[call-overload]
 
 
 def test_templates_require_directory_or_enviroment_not_both() -> None:
-    with pytest.raises(
-        AssertionError, match="either 'directory' or 'env' arguments must be passed"
-    ):
+    with pytest.raises(AssertionError, match="either 'directory' or 'env' arguments must be passed"):
         Jinja2Templates(directory="dir", env=jinja2.Environment())
 
 
@@ -157,9 +145,7 @@ def test_templates_with_directory(tmpdir: Path) -> None:
     assert template.render({}) == "Hello"
 
 
-def test_templates_with_environment(
-    tmpdir: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_templates_with_environment(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
     path = os.path.join(tmpdir, "index.html")
     with open(path, "w") as file:
         file.write("<html>Hello, <a href='{{ url_for('homepage') }}'>world</a></html>")
@@ -185,9 +171,7 @@ def test_templates_with_environment_options_emit_warning(tmpdir: Path) -> None:
         Jinja2Templates(str(tmpdir), autoescape=True)
 
 
-def test_templates_with_kwargs_only(
-    tmpdir: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_templates_with_kwargs_only(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
     # MAINTAINERS: remove after 1.0
     path = os.path.join(tmpdir, "index.html")
     with open(path, "w") as file:
@@ -242,9 +226,7 @@ def test_templates_with_kwargs_only_warns_when_no_request_keyword(
     templates = Jinja2Templates(directory=str(tmpdir))
 
     def page(request: Request) -> Response:
-        return templates.TemplateResponse(
-            name="index.html", context={"request": request}
-        )
+        return templates.TemplateResponse(name="index.html", context={"request": request})
 
     app = Starlette(routes=[Route("/", page)])
     client = test_client_factory(app)
@@ -297,9 +279,7 @@ def test_templates_warns_when_first_argument_isnot_request(
     spy.assert_called()
 
 
-def test_templates_when_first_argument_is_request(
-    tmpdir: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_templates_when_first_argument_is_request(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
     # MAINTAINERS: remove after 1.0
     path = os.path.join(tmpdir, "index.html")
     with open(path, "w") as file:
index 77de3d976fcddd9a947bf135c904cb0e2ae0dfe1..92f16d33649fd76914654ac25358e7712b38077d 100644 (file)
@@ -88,9 +88,7 @@ def test_testclient_headers_behavior() -> None:
     assert client.headers.get("Authentication") == "Bearer 123"
 
 
-def test_use_testclient_as_contextmanager(
-    test_client_factory: TestClientFactory, anyio_backend_name: str
-) -> None:
+def test_use_testclient_as_contextmanager(test_client_factory: TestClientFactory, anyio_backend_name: str) -> None:
     """
     This test asserts a number of properties that are important for an
     app level task_group
@@ -169,9 +167,7 @@ def test_use_testclient_as_contextmanager(
 
 
 def test_error_on_startup(test_client_factory: TestClientFactory) -> None:
-    with pytest.deprecated_call(
-        match="The on_startup and on_shutdown parameters are deprecated"
-    ):
+    with pytest.deprecated_call(match="The on_startup and on_shutdown parameters are deprecated"):
         startup_error_app = Starlette(on_startup=[startup])
 
     with pytest.raises(RuntimeError):
@@ -306,8 +302,7 @@ def test_query_params(test_client_factory: TestClientFactory, param: str) -> Non
             marks=[
                 pytest.mark.xfail(
                     sys.version_info < (3, 11),
-                    reason="Fails due to domain handling in http.cookiejar module (see "
-                    "#2152)",
+                    reason="Fails due to domain handling in http.cookiejar module (see #2152)",
                 ),
             ],
         ),
@@ -316,9 +311,7 @@ def test_query_params(test_client_factory: TestClientFactory, param: str) -> Non
         ("example.com", False),
     ],
 )
-def test_domain_restricted_cookies(
-    test_client_factory: TestClientFactory, domain: str, ok: bool
-) -> None:
+def test_domain_restricted_cookies(test_client_factory: TestClientFactory, domain: str, ok: bool) -> None:
     """
     Test that test client discards domain restricted cookies which do not match the
     base_url of the testclient (`http://testserver` by default).
index 16d2d0f1f884a748ad2f8a6b99aed579d53185ac..7a9b9272aba9278f9f04c8866657706b9c20faaa 100644 (file)
@@ -270,8 +270,7 @@ async def test_client_disconnect_on_send() -> None:
     async def send(message: Message) -> None:
         if message["type"] == "websocket.accept":
             return
-        # Simulate the exception the server would send to the application when the
-        # client disconnects.
+        # Simulate the exception the server would send to the application when the client disconnects.
         raise OSError
 
     with pytest.raises(WebSocketDisconnect) as ctx:
@@ -334,19 +333,8 @@ def test_send_response_multi(test_client_factory: TestClientFactory) -> None:
                 "headers": [(b"content-type", b"text/plain"), (b"foo", b"bar")],
             }
         )
-        await websocket.send(
-            {
-                "type": "websocket.http.response.body",
-                "body": b"hard",
-                "more_body": True,
-            }
-        )
-        await websocket.send(
-            {
-                "type": "websocket.http.response.body",
-                "body": b"body",
-            }
-        )
+        await websocket.send({"type": "websocket.http.response.body", "body": b"hard", "more_body": True})
+        await websocket.send({"type": "websocket.http.response.body", "body": b"body"})
 
     client = test_client_factory(app)
     with pytest.raises(WebSocketDenialResponse) as exc:
@@ -402,10 +390,7 @@ def test_send_response_duplicate_start(test_client_factory: TestClientFactory) -
     client = test_client_factory(app)
     with pytest.raises(
         RuntimeError,
-        match=(
-            'Expected ASGI message "websocket.http.response.body", but got '
-            "'websocket.http.response.start'"
-        ),
+        match=("Expected ASGI message \"websocket.http.response.body\", but got 'websocket.http.response.start'"),
     ):
         with client.websocket_connect("/"):
             pass  # pragma: no cover
@@ -493,11 +478,7 @@ def test_websocket_scope_interface() -> None:
 
     async def mock_send(message: Message) -> None: ...  # pragma: no cover
 
-    websocket = WebSocket(
-        {"type": "websocket", "path": "/abc/", "headers": []},
-        receive=mock_receive,
-        send=mock_send,
-    )
+    websocket = WebSocket({"type": "websocket", "path": "/abc/", "headers": []}, receive=mock_receive, send=mock_send)
     assert websocket["type"] == "websocket"
     assert dict(websocket) == {"type": "websocket", "path": "/abc/", "headers": []}
     assert len(websocket) == 3