[tool.hatch.version]
path = "starlette/__init__.py"
+[tool.ruff]
+line-length = 120
+
[tool.ruff.lint]
-select = ["E", "F", "I", "FA", "UP"]
-ignore = ["UP031"]
+select = [
+ "E", # https://docs.astral.sh/ruff/rules/#error-e
+ "F", # https://docs.astral.sh/ruff/rules/#pyflakes-f
+ "I", # https://docs.astral.sh/ruff/rules/#isort-i
+ "FA", # https://docs.astral.sh/ruff/rules/#flake8-future-annotations-fa
+ "UP", # https://docs.astral.sh/ruff/rules/#pyupgrade-up
+ "RUF100", # https://docs.astral.sh/ruff/rules/#ruff-specific-rules-ruf
+]
+ignore = ["UP031"] # https://docs.astral.sh/ruff/rules/printf-string-formatting/
[tool.ruff.lint.isort]
combine-as-imports = true
# that reject usedforsecurity=True
hashlib.md5(b"data", usedforsecurity=False) # type: ignore[call-arg]
- def md5_hexdigest(
- data: bytes, *, usedforsecurity: bool = True
- ) -> str: # pragma: no cover
+ def md5_hexdigest(data: bytes, *, usedforsecurity: bool = True) -> str: # pragma: no cover
return hashlib.md5( # type: ignore[call-arg]
data, usedforsecurity=usedforsecurity
).hexdigest()
StatusHandlers = typing.Dict[int, ExceptionHandler]
-def _lookup_exception_handler(
- exc_handlers: ExceptionHandlers, exc: Exception
-) -> ExceptionHandler | None:
+def _lookup_exception_handler(exc_handlers: ExceptionHandlers, exc: Exception) -> ExceptionHandler | None:
for cls in type(exc).__mro__:
if cls in exc_handlers:
return exc_handlers[cls]
while isinstance(obj, functools.partial):
obj = obj.func
- return asyncio.iscoroutinefunction(obj) or (
- callable(obj) and asyncio.iscoroutinefunction(obj.__call__)
- )
+ return asyncio.iscoroutinefunction(obj) or (callable(obj) and asyncio.iscoroutinefunction(obj.__call__))
T_co = typing.TypeVar("T_co", covariant=True)
-class AwaitableOrContextManager(
- typing.Awaitable[T_co], typing.AsyncContextManager[T_co], typing.Protocol[T_co]
-): ...
+class AwaitableOrContextManager(typing.Awaitable[T_co], typing.AsyncContextManager[T_co], typing.Protocol[T_co]): ...
class SupportsAsyncClose(typing.Protocol):
async def close(self) -> None: ... # pragma: no cover
-SupportsAsyncCloseType = typing.TypeVar(
- "SupportsAsyncCloseType", bound=SupportsAsyncClose, covariant=False
-)
+SupportsAsyncCloseType = typing.TypeVar("SupportsAsyncCloseType", bound=SupportsAsyncClose, covariant=False)
class AwaitableOrContextManagerWrapper(typing.Generic[SupportsAsyncCloseType]):
self.debug = debug
self.state = State()
- self.router = Router(
- routes, on_startup=on_startup, on_shutdown=on_shutdown, lifespan=lifespan
- )
- self.exception_handlers = (
- {} if exception_handlers is None else dict(exception_handlers)
- )
+ self.router = Router(routes, on_startup=on_startup, on_shutdown=on_shutdown, lifespan=lifespan)
+ self.exception_handlers = {} if exception_handlers is None else dict(exception_handlers)
self.user_middleware = [] if middleware is None else list(middleware)
self.middleware_stack: ASGIApp | None = None
def build_middleware_stack(self) -> ASGIApp:
debug = self.debug
error_handler = None
- exception_handlers: dict[
- typing.Any, typing.Callable[[Request, Exception], Response]
- ] = {}
+ exception_handlers: dict[typing.Any, typing.Callable[[Request, Exception], Response]] = {}
for key, value in self.exception_handlers.items():
if key in (500, Exception):
middleware = (
[Middleware(ServerErrorMiddleware, handler=error_handler, debug=debug)]
+ self.user_middleware
- + [
- Middleware(
- ExceptionMiddleware, handlers=exception_handlers, debug=debug
- )
- ]
+ + [Middleware(ExceptionMiddleware, handlers=exception_handlers, debug=debug)]
)
app = self.router
name: str | None = None,
include_in_schema: bool = True,
) -> None: # pragma: no cover
- self.router.add_route(
- path, route, methods=methods, name=name, include_in_schema=include_in_schema
- )
+ self.router.add_route(path, route, methods=methods, name=name, include_in_schema=include_in_schema)
def add_websocket_route(
self,
) -> None: # pragma: no cover
self.router.add_websocket_route(path, route, name=name)
- def exception_handler(
- self, exc_class_or_status_code: int | type[Exception]
- ) -> typing.Callable: # type: ignore[type-arg]
+ def exception_handler(self, exc_class_or_status_code: int | type[Exception]) -> typing.Callable: # type: ignore[type-arg]
warnings.warn(
- "The `exception_handler` decorator is deprecated, and will be removed in version 1.0.0. " # noqa: E501
- "Refer to https://www.starlette.io/exceptions/ for the recommended approach.", # noqa: E501
+ "The `exception_handler` decorator is deprecated, and will be removed in version 1.0.0. "
+ "Refer to https://www.starlette.io/exceptions/ for the recommended approach.",
DeprecationWarning,
)
- def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg] # noqa: E501
+ def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg]
self.add_exception_handler(exc_class_or_status_code, func)
return func
>>> app = Starlette(routes=routes)
"""
warnings.warn(
- "The `route` decorator is deprecated, and will be removed in version 1.0.0. " # noqa: E501
- "Refer to https://www.starlette.io/routing/ for the recommended approach.", # noqa: E501
+ "The `route` decorator is deprecated, and will be removed in version 1.0.0. "
+ "Refer to https://www.starlette.io/routing/ for the recommended approach.",
DeprecationWarning,
)
- def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg] # noqa: E501
+ def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg]
self.router.add_route(
path,
func,
>>> app = Starlette(routes=routes)
"""
warnings.warn(
- "The `websocket_route` decorator is deprecated, and will be removed in version 1.0.0. " # noqa: E501
- "Refer to https://www.starlette.io/routing/#websocket-routing for the recommended approach.", # noqa: E501
+ "The `websocket_route` decorator is deprecated, and will be removed in version 1.0.0. "
+ "Refer to https://www.starlette.io/routing/#websocket-routing for the recommended approach.",
DeprecationWarning,
)
- def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg] # noqa: E501
+ def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg]
self.router.add_websocket_route(path, func, name=name)
return func
return decorator
- def middleware(self, middleware_type: str) -> typing.Callable: # type: ignore[type-arg] # noqa: E501
+ def middleware(self, middleware_type: str) -> typing.Callable: # type: ignore[type-arg]
"""
We no longer document this decorator style API, and its usage is discouraged.
Instead you should use the following approach:
>>> app = Starlette(middleware=middleware)
"""
warnings.warn(
- "The `middleware` decorator is deprecated, and will be removed in version 1.0.0. " # noqa: E501
- "Refer to https://www.starlette.io/middleware/#using-middleware for recommended approach.", # noqa: E501
+ "The `middleware` decorator is deprecated, and will be removed in version 1.0.0. "
+ "Refer to https://www.starlette.io/middleware/#using-middleware for recommended approach.",
DeprecationWarning,
)
- assert (
- middleware_type == "http"
- ), 'Currently only middleware("http") is supported.'
+ assert middleware_type == "http", 'Currently only middleware("http") is supported.'
- def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg] # noqa: E501
+ def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg]
self.add_middleware(BaseHTTPMiddleware, dispatch=func)
return func
scopes: str | typing.Sequence[str],
status_code: int = 403,
redirect: str | None = None,
-) -> typing.Callable[
- [typing.Callable[_P, typing.Any]], typing.Callable[_P, typing.Any]
-]:
+) -> typing.Callable[[typing.Callable[_P, typing.Any]], typing.Callable[_P, typing.Any]]:
scopes_list = [scopes] if isinstance(scopes, str) else list(scopes)
def decorator(
type_ = parameter.name
break
else:
- raise Exception(
- f'No "request" or "websocket" argument on function "{func}"'
- )
+ raise Exception(f'No "request" or "websocket" argument on function "{func}"')
if type_ == "websocket":
# Handle websocket functions. (Always async)
@functools.wraps(func)
async def websocket_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None:
- websocket = kwargs.get(
- "websocket", args[idx] if idx < len(args) else None
- )
+ websocket = kwargs.get("websocket", args[idx] if idx < len(args) else None)
assert isinstance(websocket, WebSocket)
if not has_required_scope(websocket, scopes_list):
class AuthenticationBackend:
- async def authenticate(
- self, conn: HTTPConnection
- ) -> tuple[AuthCredentials, BaseUser] | None:
+ async def authenticate(self, conn: HTTPConnection) -> tuple[AuthCredentials, BaseUser] | None:
raise NotImplementedError() # pragma: no cover
class BackgroundTask:
- def __init__(
- self, func: typing.Callable[P, typing.Any], *args: P.args, **kwargs: P.kwargs
- ) -> None:
+ def __init__(self, func: typing.Callable[P, typing.Any], *args: P.args, **kwargs: P.kwargs) -> None:
self.func = func
self.args = args
self.kwargs = kwargs
def __init__(self, tasks: typing.Sequence[BackgroundTask] | None = None):
self.tasks = list(tasks) if tasks else []
- def add_task(
- self, func: typing.Callable[P, typing.Any], *args: P.args, **kwargs: P.kwargs
- ) -> None:
+ def add_task(self, func: typing.Callable[P, typing.Any], *args: P.args, **kwargs: P.kwargs) -> None:
task = BackgroundTask(func, *args, **kwargs)
self.tasks.append(task)
T = typing.TypeVar("T")
-async def run_until_first_complete(*args: tuple[typing.Callable, dict]) -> None: # type: ignore[type-arg] # noqa: E501
+async def run_until_first_complete(*args: tuple[typing.Callable, dict]) -> None: # type: ignore[type-arg]
warnings.warn(
- "run_until_first_complete is deprecated "
- "and will be removed in a future version.",
+ "run_until_first_complete is deprecated and will be removed in a future version.",
DeprecationWarning,
)
async with anyio.create_task_group() as task_group:
- async def run(func: typing.Callable[[], typing.Coroutine]) -> None: # type: ignore[type-arg] # noqa: E501
+ async def run(func: typing.Callable[[], typing.Coroutine]) -> None: # type: ignore[type-arg]
await func()
task_group.cancel_scope.cancel()
task_group.start_soon(run, functools.partial(func, **kwargs))
-async def run_in_threadpool(
- func: typing.Callable[P, T], *args: P.args, **kwargs: P.kwargs
-) -> T:
+async def run_in_threadpool(func: typing.Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
if kwargs: # pragma: no cover
# run_sync doesn't accept 'kwargs', so bind them in here
func = functools.partial(func, **kwargs)
def __setitem__(self, key: str, value: str) -> None:
if key in self._has_been_read:
- raise EnvironError(
- f"Attempting to set environ['{key}'], but the value has already been "
- "read."
- )
+ raise EnvironError(f"Attempting to set environ['{key}'], but the value has already been read.")
self._environ.__setitem__(key, value)
def __delitem__(self, key: str) -> None:
if key in self._has_been_read:
- raise EnvironError(
- f"Attempting to delete environ['{key}'], but the value has already "
- "been read."
- )
+ raise EnvironError(f"Attempting to delete environ['{key}'], but the value has already been read.")
self._environ.__delitem__(key)
def __iter__(self) -> typing.Iterator[str]:
) -> T: ...
@typing.overload
- def __call__(
- self, key: str, cast: type[str] = ..., default: T = ...
- ) -> T | str: ...
+ def __call__(self, key: str, cast: type[str] = ..., default: T = ...) -> T | str: ...
def __call__(
self,
mapping = {"true": True, "1": True, "false": False, "0": False}
value = value.lower()
if value not in mapping:
- raise ValueError(
- f"Config '{key}' has value '{value}'. Not a valid bool."
- )
+ raise ValueError(f"Config '{key}' has value '{value}'. Not a valid bool.")
return mapping[value]
try:
return cast(value)
except (TypeError, ValueError):
- raise ValueError(
- f"Config '{key}' has value '{value}'. Not a valid {cast.__name__}."
- )
+ raise ValueError(f"Config '{key}' has value '{value}'. Not a valid {cast.__name__}.")
return self.scheme in ("https", "wss")
def replace(self, **kwargs: typing.Any) -> URL:
- if (
- "username" in kwargs
- or "password" in kwargs
- or "hostname" in kwargs
- or "port" in kwargs
- ):
+ if "username" in kwargs or "password" in kwargs or "hostname" in kwargs or "port" in kwargs:
hostname = kwargs.pop("hostname", None)
port = kwargs.pop("port", self.port)
username = kwargs.pop("username", self.username)
value: typing.Any = args[0] if args else []
if kwargs:
- value = (
- ImmutableMultiDict(value).multi_items()
- + ImmutableMultiDict(kwargs).multi_items()
- )
+ value = ImmutableMultiDict(value).multi_items() + ImmutableMultiDict(kwargs).multi_items()
if not value:
_items: list[tuple[typing.Any, typing.Any]] = []
elif hasattr(value, "multi_items"):
- value = typing.cast(
- ImmutableMultiDict[_KeyType, _CovariantValueType], value
- )
+ value = typing.cast(ImmutableMultiDict[_KeyType, _CovariantValueType], value)
_items = list(value.multi_items())
elif hasattr(value, "items"):
value = typing.cast(typing.Mapping[_KeyType, _CovariantValueType], value)
def update(
self,
- *args: MultiDict
- | typing.Mapping[typing.Any, typing.Any]
- | list[tuple[typing.Any, typing.Any]],
+ *args: MultiDict | typing.Mapping[typing.Any, typing.Any] | list[tuple[typing.Any, typing.Any]],
**kwargs: typing.Any,
) -> None:
value = MultiDict(*args, **kwargs)
if isinstance(value, str):
super().__init__(parse_qsl(value, keep_blank_values=True), **kwargs)
elif isinstance(value, bytes):
- super().__init__(
- parse_qsl(value.decode("latin-1"), keep_blank_values=True), **kwargs
- )
+ super().__init__(parse_qsl(value.decode("latin-1"), keep_blank_values=True), **kwargs)
else:
super().__init__(*args, **kwargs) # type: ignore[arg-type]
self._list = [(str(k), str(v)) for k, v in self._list]
def __init__(
self,
- *args: FormData
- | typing.Mapping[str, str | UploadFile]
- | list[tuple[str, str | UploadFile]],
+ *args: FormData | typing.Mapping[str, str | UploadFile] | list[tuple[str, str | UploadFile]],
**kwargs: str | UploadFile,
) -> None:
super().__init__(*args, **kwargs)
if headers is not None:
assert raw is None, 'Cannot set both "headers" and "raw".'
assert scope is None, 'Cannot set both "headers" and "scope".'
- self._list = [
- (key.lower().encode("latin-1"), value.encode("latin-1"))
- for key, value in headers.items()
- ]
+ self._list = [(key.lower().encode("latin-1"), value.encode("latin-1")) for key, value in headers.items()]
elif raw is not None:
assert scope is None, 'Cannot set both "raw" and "scope".'
self._list = raw
return [value.decode("latin-1") for key, value in self._list]
def items(self) -> list[tuple[str, str]]: # type: ignore[override]
- return [
- (key.decode("latin-1"), value.decode("latin-1"))
- for key, value in self._list
- ]
+ return [(key.decode("latin-1"), value.decode("latin-1")) for key, value in self._list]
def getlist(self, key: str) -> list[str]:
get_header_key = key.lower().encode("latin-1")
- return [
- item_value.decode("latin-1")
- for item_key, item_value in self._list
- if item_key == get_header_key
- ]
+ return [item_value.decode("latin-1") for item_key, item_value in self._list if item_key == get_header_key]
def mutablecopy(self) -> MutableHeaders:
return MutableHeaders(raw=self._list[:])
async def dispatch(self) -> None:
request = Request(self.scope, receive=self.receive)
- handler_name = (
- "get"
- if request.method == "HEAD" and not hasattr(self, "head")
- else request.method.lower()
- )
-
- handler: typing.Callable[[Request], typing.Any] = getattr(
- self, handler_name, self.method_not_allowed
- )
+ handler_name = "get" if request.method == "HEAD" and not hasattr(self, "head") else request.method.lower()
+
+ handler: typing.Callable[[Request], typing.Any] = getattr(self, handler_name, self.method_not_allowed)
is_async = is_async_callable(handler)
if is_async:
response = await handler(request)
data = await self.decode(websocket, message)
await self.on_receive(websocket, data)
elif message["type"] == "websocket.disconnect":
- close_code = int(
- message.get("code") or status.WS_1000_NORMAL_CLOSURE
- )
+ close_code = int(message.get("code") or status.WS_1000_NORMAL_CLOSURE)
break
except Exception as exc:
close_code = status.WS_1011_INTERNAL_ERROR
await websocket.close(code=status.WS_1003_UNSUPPORTED_DATA)
raise RuntimeError("Malformed JSON data received.")
- assert (
- self.encoding is None
- ), f"Unsupported 'encoding' attribute {self.encoding}"
+ assert self.encoding is None, f"Unsupported 'encoding' attribute {self.encoding}"
return message["text"] if message.get("text") else message["bytes"]
async def on_connect(self, websocket: WebSocket) -> None:
class FormParser:
- def __init__(
- self, headers: Headers, stream: typing.AsyncGenerator[bytes, None]
- ) -> None:
- assert (
- multipart is not None
- ), "The `python-multipart` library must be installed to use form parsing."
+ def __init__(self, headers: Headers, stream: typing.AsyncGenerator[bytes, None]) -> None:
+ assert multipart is not None, "The `python-multipart` library must be installed to use form parsing."
self.headers = headers
self.stream = stream
self.messages: list[tuple[FormMessage, bytes]] = []
max_files: int | float = 1000,
max_fields: int | float = 1000,
) -> None:
- assert (
- multipart is not None
- ), "The `python-multipart` library must be installed to use form parsing."
+ assert multipart is not None, "The `python-multipart` library must be installed to use form parsing."
self.headers = headers
self.stream = stream
self.max_files = max_files
field = self._current_partial_header_name.lower()
if field == b"content-disposition":
self._current_part.content_disposition = self._current_partial_header_value
- self._current_part.item_headers.append(
- (field, self._current_partial_header_value)
- )
+ self._current_part.item_headers.append((field, self._current_partial_header_value))
self._current_partial_header_name = b""
self._current_partial_header_value = b""
def on_headers_finished(self) -> None:
- disposition, options = parse_options_header(
- self._current_part.content_disposition
- )
+ disposition, options = parse_options_header(self._current_part.content_disposition)
try:
- self._current_part.field_name = _user_safe_decode(
- options[b"name"], self._charset
- )
+ self._current_part.field_name = _user_safe_decode(options[b"name"], self._charset)
except KeyError:
- raise MultiPartException(
- 'The Content-Disposition header field "name" must be ' "provided."
- )
+ raise MultiPartException('The Content-Disposition header field "name" must be provided.')
if b"filename" in options:
self._current_files += 1
if self._current_files > self.max_files:
- raise MultiPartException(
- f"Too many files. Maximum number of files is {self.max_files}."
- )
+ raise MultiPartException(f"Too many files. Maximum number of files is {self.max_files}.")
filename = _user_safe_decode(options[b"filename"], self._charset)
tempfile = SpooledTemporaryFile(max_size=self.max_file_size)
self._files_to_close_on_error.append(tempfile)
else:
self._current_fields += 1
if self._current_fields > self.max_fields:
- raise MultiPartException(
- f"Too many fields. Maximum number of fields is {self.max_fields}."
- )
+ raise MultiPartException(f"Too many fields. Maximum number of fields is {self.max_fields}.")
self._current_part.file = None
def on_end(self) -> None:
class _MiddlewareClass(Protocol[P]):
- def __init__(
- self, app: ASGIApp, *args: P.args, **kwargs: P.kwargs
- ) -> None: ... # pragma: no cover
+ def __init__(self, app: ASGIApp, *args: P.args, **kwargs: P.kwargs) -> None: ... # pragma: no cover
- async def __call__(
- self, scope: Scope, receive: Receive, send: Send
- ) -> None: ... # pragma: no cover
+ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: ... # pragma: no cover
class Middleware:
self,
app: ASGIApp,
backend: AuthenticationBackend,
- on_error: typing.Callable[[HTTPConnection, AuthenticationError], Response]
- | None = None,
+ on_error: typing.Callable[[HTTPConnection, AuthenticationError], Response] | None = None,
) -> None:
self.app = app
self.backend = backend
- self.on_error: typing.Callable[
- [HTTPConnection, AuthenticationError], Response
- ] = on_error if on_error is not None else self.default_on_error
+ self.on_error: typing.Callable[[HTTPConnection, AuthenticationError], Response] = (
+ on_error if on_error is not None else self.default_on_error
+ )
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] not in ["http", "websocket"]:
from starlette.types import ASGIApp, Message, Receive, Scope, Send
RequestResponseEndpoint = typing.Callable[[Request], typing.Awaitable[Response]]
-DispatchFunction = typing.Callable[
- [Request, RequestResponseEndpoint], typing.Awaitable[Response]
-]
+DispatchFunction = typing.Callable[[Request, RequestResponseEndpoint], typing.Awaitable[Response]]
T = typing.TypeVar("T")
if app_exc is not None:
raise app_exc
- response = _StreamingResponse(
- status_code=message["status"], content=body_stream(), info=info
- )
+ response = _StreamingResponse(status_code=message["status"], content=body_stream(), info=info)
response.raw_headers = message["headers"]
return response
await response(scope, wrapped_receive, send)
response_sent.set()
- async def dispatch(
- self, request: Request, call_next: RequestResponseEndpoint
- ) -> Response:
+ async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
raise NotImplementedError() # pragma: no cover
if self.allow_all_origins:
return True
- if self.allow_origin_regex is not None and self.allow_origin_regex.fullmatch(
- origin
- ):
+ if self.allow_origin_regex is not None and self.allow_origin_regex.fullmatch(origin):
return True
return origin in self.allow_origins
return PlainTextResponse("OK", status_code=200, headers=headers)
- async def simple_response(
- self, scope: Scope, receive: Receive, send: Send, request_headers: Headers
- ) -> None:
+ async def simple_response(self, scope: Scope, receive: Receive, send: Send, request_headers: Headers) -> None:
send = functools.partial(self.send, send=send, request_headers=request_headers)
await self.app(scope, receive, send)
- async def send(
- self, message: Message, send: Send, request_headers: Headers
- ) -> None:
+ async def send(self, message: Message, send: Send, request_headers: Headers) -> None:
if message["type"] != "http.response.start":
await send(message)
return
# to optionally raise the error within the test case.
raise exc
- def format_line(
- self, index: int, line: str, frame_lineno: int, frame_index: int
- ) -> str:
+ def format_line(self, index: int, line: str, frame_lineno: int, frame_index: int) -> str:
values = {
# HTML escape - line could contain < or >
"line": html.escape(line).replace(" ", " "),
return FRAME_TEMPLATE.format(**values)
def generate_html(self, exc: Exception, limit: int = 7) -> str:
- traceback_obj = traceback.TracebackException.from_exception(
- exc, capture_locals=True
- )
+ traceback_obj = traceback.TracebackException.from_exception(exc, capture_locals=True)
exc_html = ""
is_collapsed = False
def __init__(
self,
app: ASGIApp,
- handlers: typing.Mapping[
- typing.Any, typing.Callable[[Request, Exception], Response]
- ]
- | None = None,
+ handlers: typing.Mapping[typing.Any, typing.Callable[[Request, Exception], Response]] | None = None,
debug: bool = False,
) -> None:
self.app = app
assert isinstance(exc, HTTPException)
if exc.status_code in {204, 304}:
return Response(status_code=exc.status_code, headers=exc.headers)
- return PlainTextResponse(
- exc.detail, status_code=exc.status_code, headers=exc.headers
- )
+ return PlainTextResponse(exc.detail, status_code=exc.status_code, headers=exc.headers)
async def websocket_exception(self, websocket: WebSocket, exc: Exception) -> None:
assert isinstance(exc, WebSocketException)
class GZipMiddleware:
- def __init__(
- self, app: ASGIApp, minimum_size: int = 500, compresslevel: int = 9
- ) -> None:
+ def __init__(self, app: ASGIApp, minimum_size: int = 500, compresslevel: int = 9) -> None:
self.app = app
self.minimum_size = minimum_size
self.compresslevel = compresslevel
if scope["type"] == "http":
headers = Headers(scope=scope)
if "gzip" in headers.get("Accept-Encoding", ""):
- responder = GZipResponder(
- self.app, self.minimum_size, compresslevel=self.compresslevel
- )
+ responder = GZipResponder(self.app, self.minimum_size, compresslevel=self.compresslevel)
await responder(scope, receive, send)
return
await self.app(scope, receive, send)
self.started = False
self.content_encoding_set = False
self.gzip_buffer = io.BytesIO()
- self.gzip_file = gzip.GzipFile(
- mode="wb", fileobj=self.gzip_buffer, compresslevel=compresslevel
- )
+ self.gzip_file = gzip.GzipFile(mode="wb", fileobj=self.gzip_buffer, compresslevel=compresslevel)
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
self.send = send
data = b64encode(json.dumps(scope["session"]).encode("utf-8"))
data = self.signer.sign(data)
headers = MutableHeaders(scope=message)
- header_value = "{session_cookie}={data}; path={path}; {max_age}{security_flags}".format( # noqa E501
+ header_value = "{session_cookie}={data}; path={path}; {max_age}{security_flags}".format(
session_cookie=self.session_cookie,
data=data.decode("utf-8"),
path=self.path,
elif not initial_session_was_empty:
# The session has been cleared.
headers = MutableHeaders(scope=message)
- header_value = "{session_cookie}={data}; path={path}; {expires}{security_flags}".format( # noqa E501
+ header_value = "{session_cookie}={data}; path={path}; {expires}{security_flags}".format(
session_cookie=self.session_cookie,
data="null",
path=self.path,
is_valid_host = False
found_www_redirect = False
for pattern in self.allowed_hosts:
- if host == pattern or (
- pattern.startswith("*") and host.endswith(pattern[1:])
- ):
+ if host == pattern or (pattern.startswith("*") and host.endswith(pattern[1:])):
is_valid_host = True
break
elif "www." + host == pattern:
self.scope = scope
self.status = None
self.response_headers = None
- self.stream_send, self.stream_receive = anyio.create_memory_object_stream(
- math.inf
- )
+ self.stream_send, self.stream_receive = anyio.create_memory_object_stream(math.inf)
self.response_started = False
self.exc_info: typing.Any = None
{"type": "http.response.body", "body": chunk, "more_body": True},
)
- anyio.from_thread.run(
- self.stream_send.send, {"type": "http.response.body", "body": b""}
- )
+ anyio.from_thread.run(self.stream_send.send, {"type": "http.response.body", "body": b""})
# This is used by request.url_for, it might be used inside a Mount which
# would have its own child scope with its own root_path, but the base URL
# for url_for should still be the top level app root path.
- app_root_path = base_url_scope.get(
- "app_root_path", base_url_scope.get("root_path", "")
- )
+ app_root_path = base_url_scope.get("app_root_path", base_url_scope.get("root_path", ""))
path = app_root_path
if not path.endswith("/"):
path += "/"
@property
def session(self) -> dict[str, typing.Any]:
- assert (
- "session" in self.scope
- ), "SessionMiddleware must be installed to access request.session"
+ assert "session" in self.scope, "SessionMiddleware must be installed to access request.session"
return self.scope["session"] # type: ignore[no-any-return]
@property
def auth(self) -> typing.Any:
- assert (
- "auth" in self.scope
- ), "AuthenticationMiddleware must be installed to access request.auth"
+ assert "auth" in self.scope, "AuthenticationMiddleware must be installed to access request.auth"
return self.scope["auth"]
@property
def user(self) -> typing.Any:
- assert (
- "user" in self.scope
- ), "AuthenticationMiddleware must be installed to access request.user"
+ assert "user" in self.scope, "AuthenticationMiddleware must be installed to access request.user"
return self.scope["user"]
@property
class Request(HTTPConnection):
_form: FormData | None
- def __init__(
- self, scope: Scope, receive: Receive = empty_receive, send: Send = empty_send
- ):
+ def __init__(self, scope: Scope, receive: Receive = empty_receive, send: Send = empty_send):
super().__init__(scope)
assert scope["type"] == "http"
self._receive = receive
self._json = json.loads(body)
return self._json
- async def _get_form(
- self, *, max_files: int | float = 1000, max_fields: int | float = 1000
- ) -> FormData:
+ async def _get_form(self, *, max_files: int | float = 1000, max_fields: int | float = 1000) -> FormData:
if self._form is None:
assert (
parse_options_header is not None
def form(
self, *, max_files: int | float = 1000, max_fields: int | float = 1000
) -> AwaitableOrContextManager[FormData]:
- return AwaitableOrContextManagerWrapper(
- self._get_form(max_files=max_files, max_fields=max_fields)
- )
+ return AwaitableOrContextManagerWrapper(self._get_form(max_files=max_files, max_fields=max_fields))
async def close(self) -> None:
if self._form is not None:
raw_headers: list[tuple[bytes, bytes]] = []
for name in SERVER_PUSH_HEADERS_TO_COPY:
for value in self.headers.getlist(name):
- raw_headers.append(
- (name.encode("latin-1"), value.encode("latin-1"))
- )
- await self._send(
- {"type": "http.response.push", "path": path, "headers": raw_headers}
- )
+ raw_headers.append((name.encode("latin-1"), value.encode("latin-1")))
+ await self._send({"type": "http.response.push", "path": path, "headers": raw_headers})
populate_content_length = True
populate_content_type = True
else:
- raw_headers = [
- (k.lower().encode("latin-1"), v.encode("latin-1"))
- for k, v in headers.items()
- ]
+ raw_headers = [(k.lower().encode("latin-1"), v.encode("latin-1")) for k, v in headers.items()]
keys = [h[0] for h in raw_headers]
populate_content_length = b"content-length" not in keys
populate_content_type = b"content-type" not in keys
content_type = self.media_type
if content_type is not None and populate_content_type:
- if (
- content_type.startswith("text/")
- and "charset=" not in content_type.lower()
- ):
+ if content_type.startswith("text/") and "charset=" not in content_type.lower():
content_type += "; charset=" + self.charset
raw_headers.append((b"content-type", content_type.encode("latin-1")))
headers: typing.Mapping[str, str] | None = None,
background: BackgroundTask | None = None,
) -> None:
- super().__init__(
- content=b"", status_code=status_code, headers=headers, background=background
- )
+ super().__init__(content=b"", status_code=status_code, headers=headers, background=background)
self.headers["location"] = quote(str(url), safe=":/%#?=@[]!$&'()*+,;")
if self.filename is not None:
content_disposition_filename = quote(self.filename)
if content_disposition_filename != self.filename:
- content_disposition = f"{content_disposition_type}; filename*=utf-8''{content_disposition_filename}" # noqa: E501
+ content_disposition = f"{content_disposition_type}; filename*=utf-8''{content_disposition_filename}"
else:
- content_disposition = (
- f'{content_disposition_type}; filename="{self.filename}"'
- )
+ content_disposition = f'{content_disposition_type}; filename="{self.filename}"'
self.headers.setdefault("content-disposition", content_disposition)
self.stat_result = stat_result
if stat_result is not None:
including those wrapped in functools.partial objects.
"""
warnings.warn(
- "iscoroutinefunction_or_partial is deprecated, "
- "and will be removed in a future release.",
+ "iscoroutinefunction_or_partial is deprecated, and will be removed in a future release.",
DeprecationWarning,
)
while isinstance(obj, functools.partial):
for match in PARAM_REGEX.finditer(path):
param_name, convertor_type = match.groups("str")
convertor_type = convertor_type.lstrip(":")
- assert (
- convertor_type in CONVERTOR_TYPES
- ), f"Unknown path convertor '{convertor_type}'"
+ assert convertor_type in CONVERTOR_TYPES, f"Unknown path convertor '{convertor_type}'"
convertor = CONVERTOR_TYPES[convertor_type]
path_regex += re.escape(path[idx : match.start()])
if name != self.name or seen_params != expected_params:
raise NoMatchFound(name, path_params)
- path, remaining_params = replace_params(
- self.path_format, self.param_convertors, path_params
- )
+ path, remaining_params = replace_params(self.path_format, self.param_convertors, path_params)
assert not remaining_params
return URLPath(path=path, protocol="http")
if "app" in scope:
raise HTTPException(status_code=405, headers=headers)
else:
- response = PlainTextResponse(
- "Method Not Allowed", status_code=405, headers=headers
- )
+ response = PlainTextResponse("Method Not Allowed", status_code=405, headers=headers)
await response(scope, receive, send)
else:
await self.app(scope, receive, send)
if name != self.name or seen_params != expected_params:
raise NoMatchFound(name, path_params)
- path, remaining_params = replace_params(
- self.path_format, self.param_convertors, path_params
- )
+ path, remaining_params = replace_params(self.path_format, self.param_convertors, path_params)
assert not remaining_params
return URLPath(path=path, protocol="websocket")
await self.app(scope, receive, send)
def __eq__(self, other: typing.Any) -> bool:
- return (
- isinstance(other, WebSocketRoute)
- and self.path == other.path
- and self.endpoint == other.endpoint
- )
+ return isinstance(other, WebSocketRoute) and self.path == other.path and self.endpoint == other.endpoint
def __repr__(self) -> str:
return f"{self.__class__.__name__}(path={self.path!r}, name={self.name!r})"
middleware: typing.Sequence[Middleware] | None = None,
) -> None:
assert path == "" or path.startswith("/"), "Routed paths must start with '/'"
- assert (
- app is not None or routes is not None
- ), "Either 'app=...', or 'routes=' must be specified"
+ assert app is not None or routes is not None, "Either 'app=...', or 'routes=' must be specified"
self.path = path.rstrip("/")
if app is not None:
self._base_app: ASGIApp = app
for cls, args, kwargs in reversed(middleware):
self.app = cls(app=self.app, *args, **kwargs)
self.name = name
- self.path_regex, self.path_format, self.param_convertors = compile_path(
- self.path + "/{path:path}"
- )
+ self.path_regex, self.path_format, self.param_convertors = compile_path(self.path + "/{path:path}")
@property
def routes(self) -> list[BaseRoute]:
if self.name is not None and name == self.name and "path" in path_params:
# 'name' matches "<mount_name>".
path_params["path"] = path_params["path"].lstrip("/")
- path, remaining_params = replace_params(
- self.path_format, self.param_convertors, path_params
- )
+ path, remaining_params = replace_params(self.path_format, self.param_convertors, path_params)
if not remaining_params:
return URLPath(path=path)
elif self.name is None or name.startswith(self.name + ":"):
remaining_name = name[len(self.name) + 1 :]
path_kwarg = path_params.get("path")
path_params["path"] = ""
- path_prefix, remaining_params = replace_params(
- self.path_format, self.param_convertors, path_params
- )
+ path_prefix, remaining_params = replace_params(self.path_format, self.param_convertors, path_params)
if path_kwarg is not None:
remaining_params["path"] = path_kwarg
for route in self.routes or []:
try:
url = route.url_path_for(remaining_name, **remaining_params)
- return URLPath(
- path=path_prefix.rstrip("/") + str(url), protocol=url.protocol
- )
+ return URLPath(path=path_prefix.rstrip("/") + str(url), protocol=url.protocol)
except NoMatchFound:
pass
raise NoMatchFound(name, path_params)
await self.app(scope, receive, send)
def __eq__(self, other: typing.Any) -> bool:
- return (
- isinstance(other, Mount)
- and self.path == other.path
- and self.app == other.app
- )
+ return isinstance(other, Mount) and self.path == other.path and self.app == other.app
def __repr__(self) -> str:
class_name = self.__class__.__name__
if self.name is not None and name == self.name and "path" in path_params:
# 'name' matches "<mount_name>".
path = path_params.pop("path")
- host, remaining_params = replace_params(
- self.host_format, self.param_convertors, path_params
- )
+ host, remaining_params = replace_params(self.host_format, self.param_convertors, path_params)
if not remaining_params:
return URLPath(path=path, host=host)
elif self.name is None or name.startswith(self.name + ":"):
else:
# 'name' matches "<mount_name>:<child_name>".
remaining_name = name[len(self.name) + 1 :]
- host, remaining_params = replace_params(
- self.host_format, self.param_convertors, path_params
- )
+ host, remaining_params = replace_params(self.host_format, self.param_convertors, path_params)
for route in self.routes or []:
try:
url = route.url_path_for(remaining_name, **remaining_params)
await self.app(scope, receive, send)
def __eq__(self, other: typing.Any) -> bool:
- return (
- isinstance(other, Host)
- and self.host == other.host
- and self.app == other.app
- )
+ return isinstance(other, Host) and self.host == other.host and self.app == other.app
def __repr__(self) -> str:
class_name = self.__class__.__name__
def _wrap_gen_lifespan_context(
- lifespan_context: typing.Callable[
- [typing.Any], typing.Generator[typing.Any, typing.Any, typing.Any]
- ],
+ lifespan_context: typing.Callable[[typing.Any], typing.Generator[typing.Any, typing.Any, typing.Any]],
) -> typing.Callable[[typing.Any], typing.AsyncContextManager[typing.Any]]:
cmgr = contextlib.contextmanager(lifespan_context)
async with self.lifespan_context(app) as maybe_state:
if maybe_state is not None:
if "state" not in scope:
- raise RuntimeError(
- 'The server does not support "state" in the lifespan scope.'
- )
+ raise RuntimeError('The server does not support "state" in the lifespan scope.')
scope["state"].update(maybe_state)
await send({"type": "lifespan.startup.complete"})
started = True
def __eq__(self, other: typing.Any) -> bool:
return isinstance(other, Router) and self.routes == other.routes
- def mount(
- self, path: str, app: ASGIApp, name: str | None = None
- ) -> None: # pragma: nocover
+ def mount(self, path: str, app: ASGIApp, name: str | None = None) -> None: # pragma: nocover
route = Mount(path, app=app, name=name)
self.routes.append(route)
- def host(
- self, host: str, app: ASGIApp, name: str | None = None
- ) -> None: # pragma: no cover
+ def host(self, host: str, app: ASGIApp, name: str | None = None) -> None: # pragma: no cover
route = Host(host, app=app, name=name)
self.routes.append(route)
"""
warnings.warn(
"The `route` decorator is deprecated, and will be removed in version 1.0.0."
- "Refer to https://www.starlette.io/routing/#http-routing for the recommended approach.", # noqa: E501
+ "Refer to https://www.starlette.io/routing/#http-routing for the recommended approach.",
DeprecationWarning,
)
- def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg] # noqa: E501
+ def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg]
self.add_route(
path,
func,
>>> app = Starlette(routes=routes)
"""
warnings.warn(
- "The `websocket_route` decorator is deprecated, and will be removed in version 1.0.0. Refer to " # noqa: E501
- "https://www.starlette.io/routing/#websocket-routing for the recommended approach.", # noqa: E501
+ "The `websocket_route` decorator is deprecated, and will be removed in version 1.0.0. Refer to "
+ "https://www.starlette.io/routing/#websocket-routing for the recommended approach.",
DeprecationWarning,
)
- def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg] # noqa: E501
+ def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg]
self.add_websocket_route(path, func, name=name)
return func
return decorator
- def add_event_handler(
- self, event_type: str, func: typing.Callable[[], typing.Any]
- ) -> None: # pragma: no cover
+ def add_event_handler(self, event_type: str, func: typing.Callable[[], typing.Any]) -> None: # pragma: no cover
assert event_type in ("startup", "shutdown")
if event_type == "startup":
def on_event(self, event_type: str) -> typing.Callable: # type: ignore[type-arg]
warnings.warn(
- "The `on_event` decorator is deprecated, and will be removed in version 1.0.0. " # noqa: E501
+ "The `on_event` decorator is deprecated, and will be removed in version 1.0.0. "
"Refer to https://www.starlette.io/lifespan/ for recommended approach.",
DeprecationWarning,
)
- def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg] # noqa: E501
+ def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg]
self.add_event_handler(event_type, func)
return func
def render(self, content: typing.Any) -> bytes:
assert yaml is not None, "`pyyaml` must be installed to use OpenAPIResponse."
- assert isinstance(
- content, dict
- ), "The schema passed to OpenAPIResponse should be a dictionary."
+ assert isinstance(content, dict), "The schema passed to OpenAPIResponse should be a dictionary."
return yaml.dump(content, default_flow_style=False).encode("utf-8")
for method in route.methods or ["GET"]:
if method == "HEAD":
continue
- endpoints_info.append(
- EndpointInfo(path, method.lower(), route.endpoint)
- )
+ endpoints_info.append(EndpointInfo(path, method.lower(), route.endpoint))
else:
path = self._remove_converter(route.path)
for method in ["get", "post", "put", "patch", "delete", "options"]:
"""
return re.sub(r":\w+}", "}", path)
- def parse_docstring(
- self, func_or_method: typing.Callable[..., typing.Any]
- ) -> dict[str, typing.Any]:
+ def parse_docstring(self, func_or_method: typing.Callable[..., typing.Any]) -> dict[str, typing.Any]:
"""
Given a function, parse the docstring as YAML and return a dictionary of info.
"""
def __init__(self, headers: Headers):
super().__init__(
status_code=304,
- headers={
- name: value
- for name, value in headers.items()
- if name in self.NOT_MODIFIED_HEADERS
- },
+ headers={name: value for name, value in headers.items() if name in self.NOT_MODIFIED_HEADERS},
)
spec = importlib.util.find_spec(package)
assert spec is not None, f"Package {package!r} could not be found."
assert spec.origin is not None, f"Package {package!r} could not be found."
- package_directory = os.path.normpath(
- os.path.join(spec.origin, "..", statics_dir)
- )
+ package_directory = os.path.normpath(os.path.join(spec.origin, "..", statics_dir))
assert os.path.isdir(
package_directory
), f"Directory '{statics_dir!r}' in package {package!r} could not be found."
with OS specific path separators, and any '..', '.' components removed.
"""
route_path = get_route_path(scope)
- return os.path.normpath(os.path.join(*route_path.split("/"))) # noqa: E501
+ return os.path.normpath(os.path.join(*route_path.split("/")))
async def get_response(self, path: str, scope: Scope) -> Response:
"""
raise HTTPException(status_code=405)
try:
- full_path, stat_result = await anyio.to_thread.run_sync(
- self.lookup_path, path
- )
+ full_path, stat_result = await anyio.to_thread.run_sync(self.lookup_path, path)
except PermissionError:
raise HTTPException(status_code=401)
except OSError as exc:
# We're in HTML mode, and have got a directory URL.
# Check if we have 'index.html' file to serve.
index_path = os.path.join(path, "index.html")
- full_path, stat_result = await anyio.to_thread.run_sync(
- self.lookup_path, index_path
- )
+ full_path, stat_result = await anyio.to_thread.run_sync(self.lookup_path, index_path)
if stat_result is not None and stat.S_ISREG(stat_result.st_mode):
if not scope["path"].endswith("/"):
# Directory URLs should redirect to always end in "/".
if self.html:
# Check for '404.html' if we're in HTML mode.
- full_path, stat_result = await anyio.to_thread.run_sync(
- self.lookup_path, "404.html"
- )
+ full_path, stat_result = await anyio.to_thread.run_sync(self.lookup_path, "404.html")
if stat_result and stat.S_ISREG(stat_result.st_mode):
return FileResponse(full_path, stat_result=stat_result, status_code=404)
raise HTTPException(status_code=404)
) -> Response:
request_headers = Headers(scope=scope)
- response = FileResponse(
- full_path, status_code=status_code, stat_result=stat_result
- )
+ response = FileResponse(full_path, status_code=status_code, stat_result=stat_result)
if self.is_not_modified(response.headers, request_headers):
return NotModifiedResponse(response.headers)
return response
try:
stat_result = await anyio.to_thread.run_sync(os.stat, self.directory)
except FileNotFoundError:
- raise RuntimeError(
- f"StaticFiles directory '{self.directory}' does not exist."
- )
+ raise RuntimeError(f"StaticFiles directory '{self.directory}' does not exist.")
if not (stat.S_ISDIR(stat_result.st_mode) or stat.S_ISLNK(stat_result.st_mode)):
- raise RuntimeError(
- f"StaticFiles path '{self.directory}' is not a directory."
- )
+ raise RuntimeError(f"StaticFiles path '{self.directory}' is not a directory.")
- def is_not_modified(
- self, response_headers: Headers, request_headers: Headers
- ) -> bool:
+ def is_not_modified(self, response_headers: Headers, request_headers: Headers) -> bool:
"""
Given the request and response headers, return `True` if an HTTP
"Not Modified" response could be returned instead.
try:
if_modified_since = parsedate(request_headers["if-modified-since"])
last_modified = parsedate(response_headers["last-modified"])
- if (
- if_modified_since is not None
- and last_modified is not None
- and if_modified_since >= last_modified
- ):
+ if if_modified_since is not None and last_modified is not None and if_modified_since >= last_modified:
return True
except KeyError:
pass
self,
directory: str | PathLike[str] | typing.Sequence[str | PathLike[str]],
*,
- context_processors: list[typing.Callable[[Request], dict[str, typing.Any]]]
- | None = None,
+ context_processors: list[typing.Callable[[Request], dict[str, typing.Any]]] | None = None,
**env_options: typing.Any,
) -> None: ...
self,
*,
env: jinja2.Environment,
- context_processors: list[typing.Callable[[Request], dict[str, typing.Any]]]
- | None = None,
+ context_processors: list[typing.Callable[[Request], dict[str, typing.Any]]] | None = None,
) -> None: ...
def __init__(
self,
- directory: str
- | PathLike[str]
- | typing.Sequence[str | PathLike[str]]
- | None = None,
+ directory: str | PathLike[str] | typing.Sequence[str | PathLike[str]] | None = None,
*,
- context_processors: list[typing.Callable[[Request], dict[str, typing.Any]]]
- | None = None,
+ context_processors: list[typing.Callable[[Request], dict[str, typing.Any]]] | None = None,
env: jinja2.Environment | None = None,
**env_options: typing.Any,
) -> None:
if env_options:
warnings.warn(
- "Extra environment options are deprecated. Use a preconfigured jinja2.Environment instead.", # noqa: E501
+ "Extra environment options are deprecated. Use a preconfigured jinja2.Environment instead.",
DeprecationWarning,
)
assert jinja2 is not None, "jinja2 must be installed to use Jinja2Templates"
- assert bool(directory) ^ bool(
- env
- ), "either 'directory' or 'env' arguments must be passed"
+ assert bool(directory) ^ bool(env), "either 'directory' or 'env' arguments must be passed"
self.context_processors = context_processors or []
if directory is not None:
self.env = self._create_env(directory, **env_options)
# Deprecated usage
...
- def TemplateResponse(
- self, *args: typing.Any, **kwargs: typing.Any
- ) -> _TemplateResponse:
+ def TemplateResponse(self, *args: typing.Any, **kwargs: typing.Any) -> _TemplateResponse:
if args:
- if isinstance(
- args[0], str
- ): # the first argument is template name (old style)
+ if isinstance(args[0], str): # the first argument is template name (old style)
warnings.warn(
"The `name` is not the first parameter anymore. "
"The first parameter should be the `Request` instance.\n"
- 'Replace `TemplateResponse(name, {"request": request})` by `TemplateResponse(request, name)`.', # noqa: E501
+ 'Replace `TemplateResponse(name, {"request": request})` by `TemplateResponse(request, name)`.',
DeprecationWarning,
)
name = args[0]
context = args[1] if len(args) > 1 else kwargs.get("context", {})
- status_code = (
- args[2] if len(args) > 2 else kwargs.get("status_code", 200)
- )
+ status_code = args[2] if len(args) > 2 else kwargs.get("status_code", 200)
headers = args[2] if len(args) > 2 else kwargs.get("headers")
media_type = args[3] if len(args) > 3 else kwargs.get("media_type")
background = args[4] if len(args) > 4 else kwargs.get("background")
request = args[0]
name = args[1] if len(args) > 1 else kwargs["name"]
context = args[2] if len(args) > 2 else kwargs.get("context", {})
- status_code = (
- args[3] if len(args) > 3 else kwargs.get("status_code", 200)
- )
+ status_code = args[3] if len(args) > 3 else kwargs.get("status_code", 200)
headers = args[4] if len(args) > 4 else kwargs.get("headers")
media_type = args[5] if len(args) > 5 else kwargs.get("media_type")
background = args[6] if len(args) > 6 else kwargs.get("background")
if "request" not in kwargs:
warnings.warn(
"The `TemplateResponse` now requires the `request` argument.\n"
- 'Replace `TemplateResponse(name, {"context": context})` by `TemplateResponse(request, name)`.', # noqa: E501
+ 'Replace `TemplateResponse(name, {"context": context})` by `TemplateResponse(request, name)`.',
DeprecationWarning,
)
if "request" not in kwargs.get("context", {}):
"You can install this with:\n"
" $ pip install httpx\n"
)
-_PortalFactoryType = typing.Callable[
- [], typing.ContextManager[anyio.abc.BlockingPortal]
-]
+_PortalFactoryType = typing.Callable[[], typing.ContextManager[anyio.abc.BlockingPortal]]
ASGIInstance = typing.Callable[[Receive, Send], typing.Awaitable[None]]
ASGI2App = typing.Callable[[Scope], ASGIInstance]
def _raise_on_close(self, message: Message) -> None:
if message["type"] == "websocket.close":
- raise WebSocketDisconnect(
- code=message.get("code", 1000), reason=message.get("reason", "")
- )
+ raise WebSocketDisconnect(code=message.get("code", 1000), reason=message.get("reason", ""))
elif message["type"] == "websocket.http.response.start":
status_code: int = message["status"]
headers: list[tuple[bytes, bytes]] = message["headers"]
def send_bytes(self, data: bytes) -> None:
self.send({"type": "websocket.receive", "bytes": data})
- def send_json(
- self, data: typing.Any, mode: typing.Literal["text", "binary"] = "text"
- ) -> None:
+ def send_json(self, data: typing.Any, mode: typing.Literal["text", "binary"] = "text") -> None:
text = json.dumps(data, separators=(",", ":"), ensure_ascii=False)
if mode == "text":
self.send({"type": "websocket.receive", "text": text})
self._raise_on_close(message)
return typing.cast(bytes, message["bytes"])
- def receive_json(
- self, mode: typing.Literal["text", "binary"] = "text"
- ) -> typing.Any:
+ def receive_json(self, mode: typing.Literal["text", "binary"] = "text") -> typing.Any:
message = self.receive()
self._raise_on_close(message)
if mode == "text":
headers = [(b"host", (f"{host}:{port}").encode())]
# Include other request headers.
- headers += [
- (key.lower().encode(), value.encode())
- for key, value in request.headers.multi_items()
- ]
+ headers += [(key.lower().encode(), value.encode()) for key, value in request.headers.multi_items()]
scope: dict[str, typing.Any]
nonlocal raw_kwargs, response_started, template, context
if message["type"] == "http.response.start":
- assert (
- not response_started
- ), 'Received multiple "http.response.start" messages.'
+ assert not response_started, 'Received multiple "http.response.start" messages.'
raw_kwargs["status_code"] = message["status"]
- raw_kwargs["headers"] = [
- (key.decode(), value.decode())
- for key, value in message.get("headers", [])
- ]
+ raw_kwargs["headers"] = [(key.decode(), value.decode()) for key, value in message.get("headers", [])]
response_started = True
elif message["type"] == "http.response.body":
- assert (
- response_started
- ), 'Received "http.response.body" without "http.response.start".'
- assert (
- not response_complete.is_set()
- ), 'Received "http.response.body" after response completed.'
+ assert response_started, 'Received "http.response.body" without "http.response.start".'
+ assert not response_complete.is_set(), 'Received "http.response.body" after response completed.'
body = message.get("body", b"")
more_body = message.get("more_body", False)
if request.method != "HEAD":
headers: dict[str, str] | None = None,
follow_redirects: bool = True,
) -> None:
- self.async_backend = _AsyncBackend(
- backend=backend, backend_options=backend_options or {}
- )
+ self.async_backend = _AsyncBackend(backend=backend, backend_options=backend_options or {})
if _is_asgi3(app):
asgi_app = app
else:
if self.portal is not None:
yield self.portal
else:
- with anyio.from_thread.start_blocking_portal(
- **self.async_backend
- ) as portal:
+ with anyio.from_thread.start_blocking_portal(**self.async_backend) as portal:
yield portal
def _choose_redirect_arg(
self, follow_redirects: bool | None, allow_redirects: bool | None
) -> bool | httpx._client.UseClientDefault:
- redirect: bool | httpx._client.UseClientDefault = (
- httpx._client.USE_CLIENT_DEFAULT
- )
+ redirect: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT
if allow_redirects is not None:
- message = (
- "The `allow_redirects` argument is deprecated. "
- "Use `follow_redirects` instead."
- )
+ message = "The `allow_redirects` argument is deprecated. Use `follow_redirects` instead."
warnings.warn(message, DeprecationWarning)
redirect = allow_redirects
if follow_redirects is not None:
params: httpx._types.QueryParamTypes | None = None,
headers: httpx._types.HeaderTypes | None = None,
cookies: httpx._types.CookieTypes | None = None,
- auth: httpx._types.AuthTypes
- | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
+ auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
follow_redirects: bool | None = None,
allow_redirects: bool | None = None,
- timeout: httpx._types.TimeoutTypes
- | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
+ timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
extensions: dict[str, typing.Any] | None = None,
) -> httpx.Response:
url = self._merge_url(url)
params: httpx._types.QueryParamTypes | None = None,
headers: httpx._types.HeaderTypes | None = None,
cookies: httpx._types.CookieTypes | None = None,
- auth: httpx._types.AuthTypes
- | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
+ auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
follow_redirects: bool | None = None,
allow_redirects: bool | None = None,
- timeout: httpx._types.TimeoutTypes
- | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
+ timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
extensions: dict[str, typing.Any] | None = None,
) -> httpx.Response:
redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
params: httpx._types.QueryParamTypes | None = None,
headers: httpx._types.HeaderTypes | None = None,
cookies: httpx._types.CookieTypes | None = None,
- auth: httpx._types.AuthTypes
- | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
+ auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
follow_redirects: bool | None = None,
allow_redirects: bool | None = None,
- timeout: httpx._types.TimeoutTypes
- | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
+ timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
extensions: dict[str, typing.Any] | None = None,
) -> httpx.Response:
redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
params: httpx._types.QueryParamTypes | None = None,
headers: httpx._types.HeaderTypes | None = None,
cookies: httpx._types.CookieTypes | None = None,
- auth: httpx._types.AuthTypes
- | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
+ auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
follow_redirects: bool | None = None,
allow_redirects: bool | None = None,
- timeout: httpx._types.TimeoutTypes
- | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
+ timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
extensions: dict[str, typing.Any] | None = None,
) -> httpx.Response:
redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
params: httpx._types.QueryParamTypes | None = None,
headers: httpx._types.HeaderTypes | None = None,
cookies: httpx._types.CookieTypes | None = None,
- auth: httpx._types.AuthTypes
- | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
+ auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
follow_redirects: bool | None = None,
allow_redirects: bool | None = None,
- timeout: httpx._types.TimeoutTypes
- | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
+ timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
extensions: dict[str, typing.Any] | None = None,
) -> httpx.Response:
redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
params: httpx._types.QueryParamTypes | None = None,
headers: httpx._types.HeaderTypes | None = None,
cookies: httpx._types.CookieTypes | None = None,
- auth: httpx._types.AuthTypes
- | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
+ auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
follow_redirects: bool | None = None,
allow_redirects: bool | None = None,
- timeout: httpx._types.TimeoutTypes
- | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
+ timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
extensions: dict[str, typing.Any] | None = None,
) -> httpx.Response:
redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
params: httpx._types.QueryParamTypes | None = None,
headers: httpx._types.HeaderTypes | None = None,
cookies: httpx._types.CookieTypes | None = None,
- auth: httpx._types.AuthTypes
- | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
+ auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
follow_redirects: bool | None = None,
allow_redirects: bool | None = None,
- timeout: httpx._types.TimeoutTypes
- | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
+ timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
extensions: dict[str, typing.Any] | None = None,
) -> httpx.Response:
redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
params: httpx._types.QueryParamTypes | None = None,
headers: httpx._types.HeaderTypes | None = None,
cookies: httpx._types.CookieTypes | None = None,
- auth: httpx._types.AuthTypes
- | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
+ auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
follow_redirects: bool | None = None,
allow_redirects: bool | None = None,
- timeout: httpx._types.TimeoutTypes
- | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
+ timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
extensions: dict[str, typing.Any] | None = None,
) -> httpx.Response:
redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
def __enter__(self) -> TestClient:
with contextlib.ExitStack() as stack:
- self.portal = portal = stack.enter_context(
- anyio.from_thread.start_blocking_portal(**self.async_backend)
- )
+ self.portal = portal = stack.enter_context(anyio.from_thread.start_blocking_portal(**self.async_backend))
@stack.callback
def reset_portal() -> None:
ASGIApp = typing.Callable[[Scope, Receive, Send], typing.Awaitable[None]]
StatelessLifespan = typing.Callable[[AppType], typing.AsyncContextManager[None]]
-StatefulLifespan = typing.Callable[
- [AppType], typing.AsyncContextManager[typing.Mapping[str, typing.Any]]
-]
+StatefulLifespan = typing.Callable[[AppType], typing.AsyncContextManager[typing.Mapping[str, typing.Any]]]
Lifespan = typing.Union[StatelessLifespan[AppType], StatefulLifespan[AppType]]
-HTTPExceptionHandler = typing.Callable[
- ["Request", Exception], "Response | typing.Awaitable[Response]"
-]
-WebSocketExceptionHandler = typing.Callable[
- ["WebSocket", Exception], typing.Awaitable[None]
-]
+HTTPExceptionHandler = typing.Callable[["Request", Exception], "Response | typing.Awaitable[Response]"]
+WebSocketExceptionHandler = typing.Callable[["WebSocket", Exception], typing.Awaitable[None]]
ExceptionHandler = typing.Union[HTTPExceptionHandler, WebSocketExceptionHandler]
message = await self._receive()
message_type = message["type"]
if message_type != "websocket.connect":
- raise RuntimeError(
- 'Expected ASGI message "websocket.connect", '
- f"but got {message_type!r}"
- )
+ raise RuntimeError(f'Expected ASGI message "websocket.connect", but got {message_type!r}')
self.client_state = WebSocketState.CONNECTED
return message
elif self.client_state == WebSocketState.CONNECTED:
message_type = message["type"]
if message_type not in {"websocket.receive", "websocket.disconnect"}:
raise RuntimeError(
- 'Expected ASGI message "websocket.receive" or '
- f'"websocket.disconnect", but got {message_type!r}'
+ f'Expected ASGI message "websocket.receive" or "websocket.disconnect", but got {message_type!r}'
)
if message_type == "websocket.disconnect":
self.client_state = WebSocketState.DISCONNECTED
return message
else:
- raise RuntimeError(
- 'Cannot call "receive" once a disconnect message has been received.'
- )
+ raise RuntimeError('Cannot call "receive" once a disconnect message has been received.')
async def send(self, message: Message) -> None:
"""
"""
if self.application_state == WebSocketState.CONNECTING:
message_type = message["type"]
- if message_type not in {
- "websocket.accept",
- "websocket.close",
- "websocket.http.response.start",
- }:
+ if message_type not in {"websocket.accept", "websocket.close", "websocket.http.response.start"}:
raise RuntimeError(
- 'Expected ASGI message "websocket.accept",'
- '"websocket.close" or "websocket.http.response.start",'
+ 'Expected ASGI message "websocket.accept", "websocket.close" or "websocket.http.response.start", '
f"but got {message_type!r}"
)
if message_type == "websocket.close":
message_type = message["type"]
if message_type not in {"websocket.send", "websocket.close"}:
raise RuntimeError(
- 'Expected ASGI message "websocket.send" or "websocket.close", '
- f"but got {message_type!r}"
+ f'Expected ASGI message "websocket.send" or "websocket.close", but got {message_type!r}'
)
if message_type == "websocket.close":
self.application_state = WebSocketState.DISCONNECTED
elif self.application_state == WebSocketState.RESPONSE:
message_type = message["type"]
if message_type != "websocket.http.response.body":
- raise RuntimeError(
- 'Expected ASGI message "websocket.http.response.body", '
- f"but got {message_type!r}"
- )
+ raise RuntimeError(f'Expected ASGI message "websocket.http.response.body", but got {message_type!r}')
if not message.get("more_body", False):
self.application_state = WebSocketState.DISCONNECTED
await self._send(message)
if self.client_state == WebSocketState.CONNECTING:
# If we haven't yet seen the 'connect' message, then wait for it first.
await self.receive()
- await self.send(
- {"type": "websocket.accept", "subprotocol": subprotocol, "headers": headers}
- )
+ await self.send({"type": "websocket.accept", "subprotocol": subprotocol, "headers": headers})
def _raise_on_disconnect(self, message: Message) -> None:
if message["type"] == "websocket.disconnect":
async def receive_text(self) -> str:
if self.application_state != WebSocketState.CONNECTED:
- raise RuntimeError(
- 'WebSocket is not connected. Need to call "accept" first.'
- )
+ raise RuntimeError('WebSocket is not connected. Need to call "accept" first.')
message = await self.receive()
self._raise_on_disconnect(message)
return typing.cast(str, message["text"])
async def receive_bytes(self) -> bytes:
if self.application_state != WebSocketState.CONNECTED:
- raise RuntimeError(
- 'WebSocket is not connected. Need to call "accept" first.'
- )
+ raise RuntimeError('WebSocket is not connected. Need to call "accept" first.')
message = await self.receive()
self._raise_on_disconnect(message)
return typing.cast(bytes, message["bytes"])
if mode not in {"text", "binary"}:
raise RuntimeError('The "mode" argument should be "text" or "binary".')
if self.application_state != WebSocketState.CONNECTED:
- raise RuntimeError(
- 'WebSocket is not connected. Need to call "accept" first.'
- )
+ raise RuntimeError('WebSocket is not connected. Need to call "accept" first.')
message = await self.receive()
self._raise_on_disconnect(message)
await self.send({"type": "websocket.send", "bytes": text.encode("utf-8")})
async def close(self, code: int = 1000, reason: str | None = None) -> None:
- await self.send(
- {"type": "websocket.close", "code": code, "reason": reason or ""}
- )
+ await self.send({"type": "websocket.close", "code": code, "reason": reason or ""})
async def send_denial_response(self, response: Response) -> None:
if "websocket.http.response" in self.scope.get("extensions", {}):
await response(self.scope, self.receive, self.send)
else:
- raise RuntimeError(
- "The server doesn't support the Websocket Denial Response extension."
- )
+ raise RuntimeError("The server doesn't support the Websocket Denial Response extension.")
class WebSocketClose:
self.reason = reason or ""
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
- await send(
- {"type": "websocket.close", "code": self.code, "reason": self.reason}
- )
+ await send({"type": "websocket.close", "code": self.code, "reason": self.reason})
def homepage(request: Request) -> PlainTextResponse:
return PlainTextResponse("Homepage")
- app = Starlette(
- routes=[Route("/", homepage)], middleware=[Middleware(CustomMiddleware)]
- )
+ app = Starlette(routes=[Route("/", homepage)], middleware=[Middleware(CustomMiddleware)])
client = test_client_factory(app)
response = client.get("/")
ctxvar.set("set by endpoint")
return PlainTextResponse("Homepage")
- app = Starlette(
- middleware=[Middleware(middleware_cls)], routes=[Route("/", homepage)]
- )
+ app = Starlette(middleware=[Middleware(middleware_cls)], routes=[Route("/", homepage)])
client = test_client_factory(app)
response = client.get("/")
events.append("Background task finished")
async def endpoint_with_background_task(_: Request) -> PlainTextResponse:
- return PlainTextResponse(
- content="Hello", background=BackgroundTask(sleep_and_set)
- )
+ return PlainTextResponse(content="Hello", background=BackgroundTask(sleep_and_set))
- async def passthrough(
- request: Request, call_next: RequestResponseEndpoint
- ) -> Response:
+ async def passthrough(request: Request, call_next: RequestResponseEndpoint) -> Response:
return await call_next(request)
app = Starlette(
}
)
- pytest.fail(
- "http.disconnect should have been received and canceled the scope"
- ) # pragma: no cover
+ pytest.fail("http.disconnect should have been received and canceled the scope") # pragma: no cover
app = DiscardingMiddleware(downstream_app)
await rcv_stream.aclose()
-def test_read_request_stream_in_dispatch_after_app_calls_body_with_middleware_calling_body_before_call_next( # noqa: E501
+def test_read_request_stream_in_dispatch_after_app_calls_body_with_middleware_calling_body_before_call_next(
test_client_factory: TestClientFactory,
) -> None:
async def homepage(request: Request) -> PlainTextResponse:
request: Request,
call_next: RequestResponseEndpoint,
) -> Response:
- assert (
- await request.body() == b"a"
- ) # this buffers the request body in memory
+ assert await request.body() == b"a" # this buffers the request body in memory
resp = await call_next(request)
async for chunk in request.stream():
if chunk:
assert response.status_code == 200
-def test_read_request_body_in_dispatch_after_app_calls_body_with_middleware_calling_body_before_call_next( # noqa: E501
+def test_read_request_body_in_dispatch_after_app_calls_body_with_middleware_calling_body_before_call_next(
test_client_factory: TestClientFactory,
) -> None:
async def homepage(request: Request) -> PlainTextResponse:
request: Request,
call_next: RequestResponseEndpoint,
) -> Response:
- assert (
- await request.body() == b"a"
- ) # this buffers the request body in memory
+ assert await request.body() == b"a" # this buffers the request body in memory
resp = await call_next(request)
assert await request.body() == b"a" # no problem here
return resp
self.events = events
super().__init__(app)
- async def dispatch(
- self, request: Request, call_next: RequestResponseEndpoint
- ) -> Response:
+ async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
self.events.append(f"{self.version}:STARTED")
res = await call_next(request)
self.events.append(f"{self.version}:COMPLETED")
app = Starlette(
routes=[Route("/", sleepy)],
- middleware=[
- Middleware(MyMiddleware, version=_ + 1, events=events) for _ in range(10)
- ],
+ middleware=[Middleware(MyMiddleware, version=_ + 1, events=events) for _ in range(10)],
)
scope = {
await Response(b"good!")(scope, receive, send)
class MyMiddleware(BaseHTTPMiddleware):
- async def dispatch(
- self, request: Request, call_next: RequestResponseEndpoint
- ) -> Response:
+ async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
return await call_next(request)
app = MyMiddleware(app_poll_disconnect)
app = Starlette(
routes=[Route("/", endpoint=homepage)],
- middleware=[
- Middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"])
- ],
+ middleware=[Middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"])],
)
client = test_client_factory(app)
methods=["delete", "get", "head", "options", "patch", "post", "put"],
)
],
- middleware=[
- Middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"])
- ],
+ middleware=[Middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"])],
)
client = test_client_factory(app)
response = client.get("/", headers=headers)
assert response.status_code == 200
assert response.text == "Homepage"
- assert (
- response.headers["access-control-allow-origin"]
- == "https://subdomain.example.org"
- )
+ assert response.headers["access-control-allow-origin"] == "https://subdomain.example.org"
assert "access-control-allow-credentials" not in response.headers
# Test diallowed standard response
test_client_factory: TestClientFactory,
) -> None:
def homepage(request: Request) -> PlainTextResponse:
- return PlainTextResponse(
- "Homepage", status_code=200, headers={"Vary": "Accept-Encoding"}
- )
+ return PlainTextResponse("Homepage", status_code=200, headers={"Vary": "Accept-Encoding"})
app = Starlette(
routes=[Route("/", endpoint=homepage)],
test_client_factory: TestClientFactory,
) -> None:
def homepage(request: Request) -> PlainTextResponse:
- return PlainTextResponse(
- "Homepage", status_code=200, headers={"Vary": "Accept-Encoding"}
- )
+ return PlainTextResponse("Homepage", status_code=200, headers={"Vary": "Accept-Encoding"})
app = Starlette(
routes=[Route("/", endpoint=homepage)],
)
client = test_client_factory(app)
- response = client.get(
- "/", headers={"Cookie": "foo=bar", "Origin": "https://someplace.org"}
- )
+ response = client.get("/", headers={"Cookie": "foo=bar", "Origin": "https://someplace.org"})
assert response.status_code == 200
assert response.headers["vary"] == "Accept-Encoding, Origin"
test_client_factory: TestClientFactory,
) -> None:
def homepage(request: Request) -> PlainTextResponse:
- return PlainTextResponse(
- "Homepage", status_code=200, headers={"Vary": "Accept-Encoding"}
- )
+ return PlainTextResponse("Homepage", status_code=200, headers={"Vary": "Accept-Encoding"})
app = Starlette(
routes=[
assert response.headers["access-control-allow-origin"] == "*"
assert "access-control-allow-credentials" not in response.headers
- response = client.get(
- "/", headers={"Cookie": "foo=bar", "Origin": "https://someplace.org"}
- )
+ response = client.get("/", headers={"Cookie": "foo=bar", "Origin": "https://someplace.org"})
assert response.headers["access-control-allow-origin"] == "https://someplace.org"
assert "access-control-allow-credentials" not in response.headers
yield bytes
streaming = generator(bytes=b"x" * 400, count=10)
- return StreamingResponse(
- streaming, status_code=200, headers={"Content-Encoding": "text"}
- )
+ return StreamingResponse(streaming, status_code=200, headers={"Content-Encoding": "text"})
app = Starlette(
routes=[Route("/", endpoint=homepage)],
Route("/update_session", endpoint=update_session, methods=["POST"]),
Route("/clear_session", endpoint=clear_session, methods=["POST"]),
],
- middleware=[
- Middleware(SessionMiddleware, secret_key="example", https_only=True)
- ],
+ middleware=[Middleware(SessionMiddleware, secret_key="example", https_only=True)],
)
secure_client = test_client_factory(app, base_url="https://testserver")
unsecure_client = test_client_factory(app, base_url="http://testserver")
routes=[
Route("/update_session", endpoint=update_session, methods=["POST"]),
],
- middleware=[
- Middleware(SessionMiddleware, secret_key="example", path="/second_app")
- ],
+ middleware=[Middleware(SessionMiddleware, secret_key="example", path="/second_app")],
)
app = Starlette(routes=[Mount("/second_app", app=second_app)])
client = test_client_factory(app, base_url="http://testserver/second_app")
Route("/view_session", endpoint=view_session),
Route("/update_session", endpoint=update_session, methods=["POST"]),
],
- middleware=[
- Middleware(SessionMiddleware, secret_key="example", domain=".example.com")
- ],
+ middleware=[Middleware(SessionMiddleware, secret_key="example", domain=".example.com")],
)
client: TestClient = test_client_factory(app)
app = Starlette(
routes=[Route("/", endpoint=homepage)],
- middleware=[
- Middleware(
- TrustedHostMiddleware, allowed_hosts=["testserver", "*.testserver"]
- )
- ],
+ middleware=[Middleware(TrustedHostMiddleware, allowed_hosts=["testserver", "*.testserver"])],
)
client = test_client_factory(app)
app = Starlette(
routes=[Route("/", endpoint=homepage)],
- middleware=[
- Middleware(TrustedHostMiddleware, allowed_hosts=["www.example.com"])
- ],
+ middleware=[Middleware(TrustedHostMiddleware, allowed_hosts=["www.example.com"])],
)
client = test_client_factory(app, base_url="https://example.com")
CustomWSException: custom_ws_exception_handler,
}
-middleware = [
- Middleware(TrustedHostMiddleware, allowed_hosts=["testserver", "*.example.org"])
-]
+middleware = [Middleware(TrustedHostMiddleware, allowed_hosts=["testserver", "*.example.org"])]
app = Starlette(
routes=[
nonlocal cleanup_complete
cleanup_complete = True
- with pytest.deprecated_call(
- match="The on_startup and on_shutdown parameters are deprecated"
- ):
+ with pytest.deprecated_call(match="The on_startup and on_shutdown parameters are deprecated"):
app = Starlette(
on_startup=[run_startup],
on_shutdown=[run_cleanup],
app = Starlette()
with pytest.deprecated_call(
- match=(
- "The `exception_handler` decorator is deprecated, "
- "and will be removed in version 1.0.0."
- )
+ match=("The `exception_handler` decorator is deprecated, and will be removed in version 1.0.0.")
) as record:
app.exception_handler(500)(http_exception)
assert len(record) == 1
with pytest.deprecated_call(
- match=(
- "The `middleware` decorator is deprecated, "
- "and will be removed in version 1.0.0."
- )
+ match=("The `middleware` decorator is deprecated, and will be removed in version 1.0.0.")
) as record:
- async def middleware(
- request: Request, call_next: RequestResponseEndpoint
- ) -> None: ... # pragma: no cover
+ async def middleware(request: Request, call_next: RequestResponseEndpoint) -> None: ... # pragma: no cover
app.middleware("http")(middleware)
assert len(record) == 1
with pytest.deprecated_call(
- match=(
- "The `route` decorator is deprecated, "
- "and will be removed in version 1.0.0."
- )
+ match=("The `route` decorator is deprecated, and will be removed in version 1.0.0.")
) as record:
app.route("/")(async_homepage)
assert len(record) == 1
with pytest.deprecated_call(
- match=(
- "The `websocket_route` decorator is deprecated, "
- "and will be removed in version 1.0.0."
- )
+ match=("The `websocket_route` decorator is deprecated, and will be removed in version 1.0.0.")
) as record:
app.websocket_route("/ws")(websocket_endpoint)
assert len(record) == 1
with pytest.deprecated_call(
- match=(
- "The `on_event` decorator is deprecated, "
- "and will be removed in version 1.0.0."
- )
+ match=("The `on_event` decorator is deprecated, and will be removed in version 1.0.0.")
) as record:
async def startup() -> None: ... # pragma: no cover
response = client.get("/dashboard/decorated")
assert response.status_code == 403
- response = client.get(
- "/dashboard/decorated/sync", auth=("tomchristie", "example")
- )
+ response = client.get("/dashboard/decorated/sync", auth=("tomchristie", "example"))
assert response.status_code == 200
assert response.json() == {
"authenticated": True,
pass # pragma: nocover
with pytest.raises(WebSocketDisconnect):
- with client.websocket_connect(
- "/ws", headers={"Authorization": "basic foobar"}
- ):
+ with client.websocket_connect("/ws", headers={"Authorization": "basic foobar"}):
pass # pragma: nocover
- with client.websocket_connect(
- "/ws", auth=("tomchristie", "example")
- ) as websocket:
+ with client.websocket_connect("/ws", auth=("tomchristie", "example")) as websocket:
data = websocket.receive_json()
assert data == {"authenticated": True, "user": "tomchristie"}
pass # pragma: nocover
with pytest.raises(WebSocketDisconnect):
- with client.websocket_connect(
- "/ws/decorated", headers={"Authorization": "basic foobar"}
- ):
+ with client.websocket_connect("/ws/decorated", headers={"Authorization": "basic foobar"}):
pass # pragma: nocover
- with client.websocket_connect(
- "/ws/decorated", auth=("tomchristie", "example")
- ) as websocket:
+ with client.websocket_connect("/ws/decorated", auth=("tomchristie", "example")) as websocket:
data = websocket.receive_json()
assert data == {
"authenticated": True,
with test_client_factory(app) as client:
response = client.get("/admin")
assert response.status_code == 200
- url = "{}?{}".format(
- "http://testserver/", urlencode({"next": "http://testserver/admin"})
- )
+ url = "{}?{}".format("http://testserver/", urlencode({"next": "http://testserver/admin"}))
assert response.url == url
response = client.get("/admin", auth=("tomchristie", "example"))
response = client.get("/admin/sync")
assert response.status_code == 200
- url = "{}?{}".format(
- "http://testserver/", urlencode({"next": "http://testserver/admin/sync"})
- )
+ url = "{}?{}".format("http://testserver/", urlencode({"next": "http://testserver/admin/sync"}))
assert response.url == url
response = client.get("/admin/sync", auth=("tomchristie", "example"))
other_app = Starlette(
routes=[Route("/control-panel", control_panel)],
- middleware=[
- Middleware(
- AuthenticationMiddleware, backend=BasicAuth(), on_error=on_auth_error
- )
- ],
+ middleware=[Middleware(AuthenticationMiddleware, backend=BasicAuth(), on_error=on_auth_error)],
)
assert response.status_code == 200
assert response.json() == {"authenticated": True, "user": "tomchristie"}
- response = client.get(
- "/control-panel", headers={"Authorization": "basic foobar"}
- )
+ response = client.get("/control-panel", headers={"Authorization": "basic foobar"})
assert response.status_code == 401
assert response.json() == {"error": "Invalid basic auth credentials"}
tasks.add_task(increment, amount=1)
tasks.add_task(increment, amount=2)
tasks.add_task(increment, amount=3)
- response = Response(
- "tasks initiated", media_type="text/plain", background=tasks
- )
+ response = Response("tasks initiated", media_type="text/plain", background=tasks)
await response(scope, receive, send)
client = test_client_factory(app)
tasks = BackgroundTasks()
tasks.add_task(increment)
tasks.add_task(increment)
- response = Response(
- "tasks initiated", media_type="text/plain", background=tasks
- )
+ response = Response("tasks initiated", media_type="text/plain", background=tasks)
await response(scope, receive, send)
client = test_client_factory(app)
"""
We use `assert_type` to test the types returned by Config via mypy.
"""
- config = Config(
- environ={"STR": "some_str_value", "STR_CAST": "some_str_value", "BOOL": "true"}
- )
+ config = Config(environ={"STR": "some_str_value", "STR_CAST": "some_str_value", "BOOL": "true"})
assert_type(config("STR"), str)
assert_type(config("STR_DEFAULT", default=""), str)
def test_config_with_env_prefix(tmpdir: Path, monkeypatch: pytest.MonkeyPatch) -> None:
- config = Config(
- environ={"APP_DEBUG": "value", "ENVIRONMENT": "dev"}, env_prefix="APP_"
- )
+ config = Config(environ={"APP_DEBUG": "value", "ENVIRONMENT": "dev"}, env_prefix="APP_")
assert config.get("DEBUG") == "value"
with pytest.raises(KeyError):
)
-def test_datetime_convertor(
- test_client_factory: TestClientFactory, app: Router
-) -> None:
+def test_datetime_convertor(test_client_factory: TestClientFactory, app: Router) -> None:
client = test_client_factory(app)
response = client.get("/datetime/2020-01-01T00:00:00")
assert response.json() == {"datetime": "2020-01-01T00:00:00"}
assert (
- app.url_path_for("datetime-convertor", param=datetime(1996, 1, 22, 23, 0, 0))
- == "/datetime/1996-01-22T23:00:00"
+ app.url_path_for("datetime-convertor", param=datetime(1996, 1, 22, 23, 0, 0)) == "/datetime/1996-01-22T23:00:00"
)
@pytest.mark.parametrize("param, status_code", [("1.0", 200), ("1-0", 404)])
-def test_default_float_convertor(
- test_client_factory: TestClientFactory, param: str, status_code: int
-) -> None:
+def test_default_float_convertor(test_client_factory: TestClientFactory, param: str, status_code: int) -> None:
def float_convertor(request: Request) -> JSONResponse:
param = request.path_params["param"]
assert isinstance(param, float)
def test_url_from_scope() -> None:
- u = URL(
- scope={"path": "/path/to/somewhere", "query_string": b"abc=123", "headers": []}
- )
+ u = URL(scope={"path": "/path/to/somewhere", "query_string": b"abc=123", "headers": []})
assert u == "/path/to/somewhere?abc=123"
assert repr(u) == "URL('/path/to/somewhere?abc=123')"
assert dict(q) == {"a": "456", "b": "789"}
assert str(q) == "a=123&a=456&b=789"
assert repr(q) == "QueryParams('a=123&a=456&b=789')"
- assert QueryParams({"a": "123", "b": "456"}) == QueryParams(
- [("a", "123"), ("b", "456")]
- )
+ assert QueryParams({"a": "123", "b": "456"}) == QueryParams([("a", "123"), ("b", "456")])
assert QueryParams({"a": "123", "b": "456"}) == QueryParams("a=123&b=456")
- assert QueryParams({"a": "123", "b": "456"}) == QueryParams(
- {"b": "456", "a": "123"}
- )
+ assert QueryParams({"a": "123", "b": "456"}) == QueryParams({"b": "456", "a": "123"})
assert QueryParams() == QueryParams({})
assert QueryParams([("a", "123"), ("a", "456")]) == QueryParams("a=123&a=456")
assert QueryParams({"a": "123", "b": "456"}) != "invalid"
assert len(form) == 2
assert list(form) == ["a", "b"]
assert dict(form) == {"a": "456", "b": upload}
- assert (
- repr(form)
- == "FormData([('a', '123'), ('a', '456'), ('b', " + repr(upload) + ")])"
- )
+ assert repr(form) == "FormData([('a', '123'), ('a', '456'), ('b', " + repr(upload) + ")])"
assert FormData(form) == form
assert FormData({"a": "123", "b": "789"}) == FormData([("a", "123"), ("b", "789")])
assert FormData({"a": "123", "b": "789"}) != {"a": "123", "b": "789"}
async def test_upload_file_repr_headers() -> None:
stream = io.BytesIO(b"data")
file = UploadFile(filename="file", file=stream, headers=Headers({"foo": "bar"}))
- assert (
- repr(file)
- == "UploadFile(filename='file', size=None, headers=Headers({'foo': 'bar'}))"
- )
+ assert repr(file) == "UploadFile(filename='file', size=None, headers=Headers({'foo': 'bar'}))"
def test_multidict() -> None:
assert dict(q) == {"a": "456", "b": "789"}
assert str(q) == "MultiDict([('a', '123'), ('a', '456'), ('b', '789')])"
assert repr(q) == "MultiDict([('a', '123'), ('a', '456'), ('b', '789')])"
- assert MultiDict({"a": "123", "b": "456"}) == MultiDict(
- [("a", "123"), ("b", "456")]
- )
+ assert MultiDict({"a": "123", "b": "456"}) == MultiDict([("a", "123"), ("b", "456")])
assert MultiDict({"a": "123", "b": "456"}) == MultiDict({"b": "456", "a": "123"})
assert MultiDict() == MultiDict({})
assert MultiDict({"a": "123", "b": "456"}) != "invalid"
return PlainTextResponse(f"Hello, {username}!")
-app = Router(
- routes=[Route("/", endpoint=Homepage), Route("/{username}", endpoint=Homepage)]
-)
+app = Router(routes=[Route("/", endpoint=Homepage), Route("/{username}", endpoint=Homepage)])
@pytest.fixture
raise BadBodyException(422)
-async def handler_that_reads_body(
- request: Request, exc: BadBodyException
-) -> JSONResponse:
+async def handler_that_reads_body(request: Request, exc: BadBodyException) -> JSONResponse:
body = await request.body()
return JSONResponse(status_code=422, content={"body": body.decode()})
def test_http_repr() -> None:
- assert repr(HTTPException(404)) == (
- "HTTPException(status_code=404, detail='Not Found')"
- )
+ assert repr(HTTPException(404)) == ("HTTPException(status_code=404, detail='Not Found')")
assert repr(HTTPException(404, detail="Not Found: foo")) == (
"HTTPException(status_code=404, detail='Not Found: foo')"
)
return app
-def test_multipart_request_data(
- tmpdir: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_multipart_request_data(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
client = test_client_factory(app)
response = client.post("/", data={"some": "data"}, files=FORCE_MULTIPART)
assert response.json() == {"some": "data"}
-def test_multipart_request_files(
- tmpdir: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_multipart_request_files(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
path = os.path.join(tmpdir, "test.txt")
with open(path, "wb") as file:
file.write(b"<file content>")
}
-def test_multipart_request_files_with_content_type(
- tmpdir: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_multipart_request_files_with_content_type(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
path = os.path.join(tmpdir, "test.txt")
with open(path, "wb") as file:
file.write(b"<file content>")
}
-def test_multipart_request_multiple_files(
- tmpdir: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_multipart_request_multiple_files(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
path1 = os.path.join(tmpdir, "test1.txt")
with open(path1, "wb") as file:
file.write(b"<file1 content>")
client = test_client_factory(app)
with open(path1, "rb") as f1, open(path2, "rb") as f2:
- response = client.post(
- "/", files={"test1": f1, "test2": ("test2.txt", f2, "text/plain")}
- )
+ response = client.post("/", files={"test1": f1, "test2": ("test2.txt", f2, "text/plain")})
assert response.json() == {
"test1": {
"filename": "test1.txt",
}
-def test_multipart_request_multiple_files_with_headers(
- tmpdir: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_multipart_request_multiple_files_with_headers(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
path1 = os.path.join(tmpdir, "test1.txt")
with open(path1, "wb") as file:
file.write(b"<file1 content>")
}
-def test_multipart_request_mixed_files_and_data(
- tmpdir: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_multipart_request_mixed_files_and_data(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
client = test_client_factory(app)
response = client.post(
"/",
b"value1\r\n"
b"--a7f7ac8d4e2e437c877bb7b8d7cc549c--\r\n"
),
- headers={
- "Content-Type": (
- "multipart/form-data; boundary=a7f7ac8d4e2e437c877bb7b8d7cc549c"
- )
- },
+ headers={"Content-Type": ("multipart/form-data; boundary=a7f7ac8d4e2e437c877bb7b8d7cc549c")},
)
assert response.json() == {
"file": {
}
-def test_multipart_request_with_charset_for_filename(
- tmpdir: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_multipart_request_with_charset_for_filename(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
client = test_client_factory(app)
response = client.post(
"/",
data=(
# file
b"--a7f7ac8d4e2e437c877bb7b8d7cc549c\r\n" # type: ignore
- b'Content-Disposition: form-data; name="file"; filename="\xe6\x96\x87\xe6\x9b\xb8.txt"\r\n' # noqa: E501
+ b'Content-Disposition: form-data; name="file"; filename="\xe6\x96\x87\xe6\x9b\xb8.txt"\r\n'
b"Content-Type: text/plain\r\n\r\n"
b"<file content>\r\n"
b"--a7f7ac8d4e2e437c877bb7b8d7cc549c--\r\n"
),
- headers={
- "Content-Type": (
- "multipart/form-data; charset=utf-8; "
- "boundary=a7f7ac8d4e2e437c877bb7b8d7cc549c"
- )
- },
+ headers={"Content-Type": ("multipart/form-data; charset=utf-8; boundary=a7f7ac8d4e2e437c877bb7b8d7cc549c")},
)
assert response.json() == {
"file": {
}
-def test_multipart_request_without_charset_for_filename(
- tmpdir: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_multipart_request_without_charset_for_filename(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
client = test_client_factory(app)
response = client.post(
"/",
data=(
# file
b"--a7f7ac8d4e2e437c877bb7b8d7cc549c\r\n" # type: ignore
- b'Content-Disposition: form-data; name="file"; filename="\xe7\x94\xbb\xe5\x83\x8f.jpg"\r\n' # noqa: E501
+ b'Content-Disposition: form-data; name="file"; filename="\xe7\x94\xbb\xe5\x83\x8f.jpg"\r\n'
b"Content-Type: image/jpeg\r\n\r\n"
b"<file content>\r\n"
b"--a7f7ac8d4e2e437c877bb7b8d7cc549c--\r\n"
),
- headers={
- "Content-Type": (
- "multipart/form-data; boundary=a7f7ac8d4e2e437c877bb7b8d7cc549c"
- )
- },
+ headers={"Content-Type": ("multipart/form-data; boundary=a7f7ac8d4e2e437c877bb7b8d7cc549c")},
)
assert response.json() == {
"file": {
}
-def test_multipart_request_with_encoded_value(
- tmpdir: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_multipart_request_with_encoded_value(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
client = test_client_factory(app)
response = client.post(
"/",
b"Transf\xc3\xa9rer\r\n"
b"--20b303e711c4ab8c443184ac833ab00f--\r\n"
),
- headers={
- "Content-Type": (
- "multipart/form-data; charset=utf-8; "
- "boundary=20b303e711c4ab8c443184ac833ab00f"
- )
- },
+ headers={"Content-Type": ("multipart/form-data; charset=utf-8; boundary=20b303e711c4ab8c443184ac833ab00f")},
)
assert response.json() == {"value": "Transférer"}
-def test_urlencoded_request_data(
- tmpdir: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_urlencoded_request_data(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
client = test_client_factory(app)
response = client.post("/", data={"some": "data"})
assert response.json() == {"some": "data"}
assert response.json() == {}
-def test_urlencoded_percent_encoding(
- tmpdir: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_urlencoded_percent_encoding(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
client = test_client_factory(app)
response = client.post("/", data={"some": "da ta"})
assert response.json() == {"some": "da ta"}
-def test_urlencoded_percent_encoding_keys(
- tmpdir: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_urlencoded_percent_encoding_keys(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
client = test_client_factory(app)
response = client.post("/", data={"so me": "data"})
assert response.json() == {"so me": "data"}
-def test_urlencoded_multi_field_app_reads_body(
- tmpdir: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_urlencoded_multi_field_app_reads_body(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
client = test_client_factory(app_read_body)
response = client.post("/", data={"some": "data", "second": "key pair"})
assert response.json() == {"some": "data", "second": "key pair"}
-def test_multipart_multi_field_app_reads_body(
- tmpdir: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_multipart_multi_field_app_reads_body(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
client = test_client_factory(app_read_body)
- response = client.post(
- "/", data={"some": "data", "second": "key pair"}, files=FORCE_MULTIPART
- )
+ response = client.post("/", data={"some": "data", "second": "key pair"}, files=FORCE_MULTIPART)
assert response.json() == {"some": "data", "second": "key pair"}
"/",
data=(
# file
- b'Content-Disposition: form-data; name="file"; filename="\xe6\x96\x87\xe6\x9b\xb8.txt"\r\n' # type: ignore # noqa: E501
+ b'Content-Disposition: form-data; name="file"; filename="\xe6\x96\x87\xe6\x9b\xb8.txt"\r\n' # type: ignore
b"Content-Type: text/plain\r\n\r\n"
b"<file content>\r\n"
),
b'Content-Disposition: form-data; ="field0"\r\n\r\n'
b"value0\r\n"
),
- headers={
- "Content-Type": (
- "multipart/form-data; boundary=a7f7ac8d4e2e437c877bb7b8d7cc549c"
- )
- },
+ headers={"Content-Type": ("multipart/form-data; boundary=a7f7ac8d4e2e437c877bb7b8d7cc549c")},
)
assert res.status_code == 400
- assert (
- res.text == 'The Content-Disposition header field "name" must be provided.'
- )
+ assert res.text == 'The Content-Disposition header field "name" must be provided.'
@pytest.mark.parametrize(
client = test_client_factory(app)
fields = []
for i in range(1001):
- fields.append(
- "--B\r\n" f'Content-Disposition: form-data; name="N{i}";\r\n\r\n' "\r\n"
- )
+ fields.append("--B\r\n" f'Content-Disposition: form-data; name="N{i}";\r\n\r\n' "\r\n")
data = "".join(fields).encode("utf-8")
with expectation:
res = client.post(
client = test_client_factory(app)
fields = []
for i in range(1001):
- fields.append(
- "--B\r\n"
- f'Content-Disposition: form-data; name="N{i}"; filename="F{i}";\r\n\r\n'
- "\r\n"
- )
+ fields.append("--B\r\n" f'Content-Disposition: form-data; name="N{i}"; filename="F{i}";\r\n\r\n' "\r\n")
data = "".join(fields).encode("utf-8")
with expectation:
res = client.post(
for i in range(1001):
# This uses the same field name "N" for all files, equivalent to a
# multifile upload form field
- fields.append(
- "--B\r\n"
- f'Content-Disposition: form-data; name="N"; filename="F{i}";\r\n\r\n'
- "\r\n"
- )
+ fields.append("--B\r\n" f'Content-Disposition: form-data; name="N"; filename="F{i}";\r\n\r\n' "\r\n")
data = "".join(fields).encode("utf-8")
with expectation:
res = client.post(
client = test_client_factory(app)
fields = []
for i in range(1001):
- fields.append(
- "--B\r\n"
- f'Content-Disposition: form-data; name="F{i}"; filename="F{i}";\r\n\r\n'
- "\r\n"
- )
- fields.append(
- "--B\r\n" f'Content-Disposition: form-data; name="N{i}";\r\n\r\n' "\r\n"
- )
+ fields.append("--B\r\n" f'Content-Disposition: form-data; name="F{i}"; filename="F{i}";\r\n\r\n' "\r\n")
+ fields.append("--B\r\n" f'Content-Disposition: form-data; name="N{i}";\r\n\r\n' "\r\n")
data = "".join(fields).encode("utf-8")
with expectation:
res = client.post(
client = test_client_factory(app)
fields = []
for i in range(2):
- fields.append(
- "--B\r\n" f'Content-Disposition: form-data; name="N{i}";\r\n\r\n' "\r\n"
- )
+ fields.append("--B\r\n" f'Content-Disposition: form-data; name="N{i}";\r\n\r\n' "\r\n")
data = "".join(fields).encode("utf-8")
with expectation:
res = client.post(
client = test_client_factory(app)
fields = []
for i in range(2):
- fields.append(
- "--B\r\n"
- f'Content-Disposition: form-data; name="F{i}"; filename="F{i}";\r\n\r\n'
- "\r\n"
- )
+ fields.append("--B\r\n" f'Content-Disposition: form-data; name="F{i}"; filename="F{i}";\r\n\r\n' "\r\n")
data = "".join(fields).encode("utf-8")
with expectation:
res = client.post(
client = test_client_factory(make_app_max_parts(max_fields=2000, max_files=2000))
fields = []
for i in range(2000):
- fields.append(
- "--B\r\n" f'Content-Disposition: form-data; name="N{i}";\r\n\r\n' "\r\n"
- )
- fields.append(
- "--B\r\n"
- f'Content-Disposition: form-data; name="F{i}"; filename="F{i}";\r\n\r\n'
- "\r\n"
- )
+ fields.append("--B\r\n" f'Content-Disposition: form-data; name="N{i}";\r\n\r\n' "\r\n")
+ fields.append("--B\r\n" f'Content-Disposition: form-data; name="F{i}"; filename="F{i}";\r\n\r\n' "\r\n")
data = "".join(fields).encode("utf-8")
data += b"--B--\r\n"
res = client.post(
cleanup_task = BackgroundTask(numbers_for_cleanup, start=6, stop=9)
generator = numbers(1, 5)
- response = StreamingResponse(
- generator, media_type="text/plain", background=cleanup_task
- )
+ response = StreamingResponse(generator, media_type="text/plain", background=cleanup_task)
await response(scope, receive, send)
assert filled_by_bg_task == ""
cleanup_task = BackgroundTask(numbers_for_cleanup, start=6, stop=9)
async def app(scope: Scope, receive: Receive, send: Send) -> None:
- response = FileResponse(
- path=path, filename="example.png", background=cleanup_task
- )
+ response = FileResponse(path=path, filename="example.png", background=cleanup_task)
await response(scope, receive, send)
assert filled_by_bg_task == ""
await app({"type": "http", "method": "head"}, receive, send)
-def test_file_response_set_media_type(
- tmp_path: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_file_response_set_media_type(tmp_path: Path, test_client_factory: TestClientFactory) -> None:
path = tmp_path / "xyz"
path.write_bytes(b"<file content>")
assert response.headers["content-type"] == "image/jpeg"
-def test_file_response_with_directory_raises_error(
- tmp_path: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_file_response_with_directory_raises_error(tmp_path: Path, test_client_factory: TestClientFactory) -> None:
app = FileResponse(path=tmp_path, filename="example.png")
client = test_client_factory(app)
with pytest.raises(RuntimeError) as exc_info:
assert "is not a file" in str(exc_info.value)
-def test_file_response_with_missing_file_raises_error(
- tmp_path: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_file_response_with_missing_file_raises_error(tmp_path: Path, test_client_factory: TestClientFactory) -> None:
path = tmp_path / "404.txt"
app = FileResponse(path=path, filename="404.txt")
client = test_client_factory(app)
assert "does not exist" in str(exc_info.value)
-def test_file_response_with_chinese_filename(
- tmp_path: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_file_response_with_chinese_filename(tmp_path: Path, test_client_factory: TestClientFactory) -> None:
content = b"file content"
filename = "ä½ å¥½.txt" # probably "Hello.txt" in Chinese
path = tmp_path / filename
assert response.headers["content-disposition"] == expected_disposition
-def test_file_response_with_inline_disposition(
- tmp_path: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_file_response_with_inline_disposition(tmp_path: Path, test_client_factory: TestClientFactory) -> None:
content = b"file content"
filename = "hello.txt"
path = tmp_path / filename
FileResponse(path=tmp_path, filename="example.png", method="GET")
-def test_set_cookie(
- test_client_factory: TestClientFactory, monkeypatch: pytest.MonkeyPatch
-) -> None:
+def test_set_cookie(test_client_factory: TestClientFactory, monkeypatch: pytest.MonkeyPatch) -> None:
# Mock time used as a reference for `Expires` by stdlib `SimpleCookie`.
mocked_now = dt.datetime(2037, 1, 22, 12, 0, 0, tzinfo=dt.timezone.utc)
monkeypatch.setattr(time, "time", lambda: mocked_now.timestamp())
response = client.get("/")
assert response.text == "Hello, world!"
assert (
- response.headers["set-cookie"]
- == "mycookie=myvalue; Domain=localhost; expires=Thu, 22 Jan 2037 12:00:10 GMT; "
+ response.headers["set-cookie"] == "mycookie=myvalue; Domain=localhost; expires=Thu, 22 Jan 2037 12:00:10 GMT; "
"HttpOnly; Max-Age=10; Path=/; SameSite=none; Secure"
)
@pytest.mark.parametrize(
"expires",
[
- pytest.param(
- dt.datetime(2037, 1, 22, 12, 0, 10, tzinfo=dt.timezone.utc), id="datetime"
- ),
+ pytest.param(dt.datetime(2037, 1, 22, 12, 0, 10, tzinfo=dt.timezone.utc), id="datetime"),
pytest.param("Thu, 22 Jan 2037 12:00:10 GMT", id="str"),
pytest.param(10, id="int"),
],
assert response.headers["content-type"] == "text/plain; charset=utf-8"
-def test_file_response_known_size(
- tmp_path: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_file_response_known_size(tmp_path: Path, test_client_factory: TestClientFactory) -> None:
path = tmp_path / "xyz"
content = b"<file content>" * 1000
path.write_bytes(content)
def test_streaming_response_known_size(test_client_factory: TestClientFactory) -> None:
- app = StreamingResponse(
- content=iter(["hello", "world"]), headers={"content-length": "10"}
- )
+ app = StreamingResponse(content=iter(["hello", "world"]), headers={"content-length": "10"})
client: TestClient = test_client_factory(app)
response = client.get("/")
assert response.headers["content-length"] == "10"
response = client.get("/path-with-parentheses(7)")
assert response.status_code == 200
assert response.json() == {"int": 7}
- assert (
- app.url_path_for("path-with-parentheses", param=7)
- == "/path-with-parentheses(7)"
- )
+ assert app.url_path_for("path-with-parentheses", param=7) == "/path-with-parentheses(7)"
# Test float conversion
response = client.get("/float/25.5")
response = client.get("/path/some/example")
assert response.status_code == 200
assert response.json() == {"path": "some/example"}
- assert (
- app.url_path_for("path-convertor", param="some/example") == "/path/some/example"
- )
+ assert app.url_path_for("path-convertor", param="some/example") == "/path/some/example"
# Test UUID conversion
response = client.get("/uuid/ec38df32-ceda-4cfa-9b4a-1aeb94ad551a")
assert response.status_code == 200
assert response.json() == {"uuid": "ec38df32-ceda-4cfa-9b4a-1aeb94ad551a"}
assert (
- app.url_path_for(
- "uuid-convertor", param=uuid.UUID("ec38df32-ceda-4cfa-9b4a-1aeb94ad551a")
- )
+ app.url_path_for("uuid-convertor", param=uuid.UUID("ec38df32-ceda-4cfa-9b4a-1aeb94ad551a"))
== "/uuid/ec38df32-ceda-4cfa-9b4a-1aeb94ad551a"
)
assert app.url_path_for("homepage") == "/"
assert app.url_path_for("user", username="tomchristie") == "/users/tomchristie"
assert app.url_path_for("websocket_endpoint") == "/ws"
- with pytest.raises(
- NoMatchFound, match='No route exists for name "broken" and params "".'
- ):
+ with pytest.raises(NoMatchFound, match='No route exists for name "broken" and params "".'):
assert app.url_path_for("broken")
- with pytest.raises(
- NoMatchFound, match='No route exists for name "broken" and params "key, key2".'
- ):
+ with pytest.raises(NoMatchFound, match='No route exists for name "broken" and params "key, key2".'):
assert app.url_path_for("broken", key="value", key2="value2")
with pytest.raises(AssertionError):
app.url_path_for("user", username="tom/christie")
def test_url_for() -> None:
+ assert app.url_path_for("homepage").make_absolute_url(base_url="https://example.org") == "https://example.org/"
assert (
- app.url_path_for("homepage").make_absolute_url(base_url="https://example.org")
- == "https://example.org/"
- )
- assert (
- app.url_path_for("homepage").make_absolute_url(
- base_url="https://example.org/root_path/"
- )
+ app.url_path_for("homepage").make_absolute_url(base_url="https://example.org/root_path/")
== "https://example.org/root_path/"
)
assert (
- app.url_path_for("user", username="tomchristie").make_absolute_url(
- base_url="https://example.org"
- )
+ app.url_path_for("user", username="tomchristie").make_absolute_url(base_url="https://example.org")
== "https://example.org/users/tomchristie"
)
assert (
- app.url_path_for("user", username="tomchristie").make_absolute_url(
- base_url="https://example.org/root_path/"
- )
+ app.url_path_for("user", username="tomchristie").make_absolute_url(base_url="https://example.org/root_path/")
== "https://example.org/root_path/users/tomchristie"
)
assert (
- app.url_path_for("websocket_endpoint").make_absolute_url(
- base_url="https://example.org"
- )
+ app.url_path_for("websocket_endpoint").make_absolute_url(base_url="https://example.org")
== "wss://example.org/ws"
)
users = Router([Route("/{username}", ok, name="user")])
mounted = Router([Mount("/{subpath}/users", users, name="users")])
- assert (
- mounted.url_path_for("users:user", subpath="test", username="tom")
- == "/test/users/tom"
- )
- assert (
- mounted.url_path_for("users", subpath="test", path="/tom") == "/test/users/tom"
- )
+ assert mounted.url_path_for("users:user", subpath="test", username="tom") == "/test/users/tom"
+ assert mounted.url_path_for("users", subpath="test", path="/tom") == "/test/users/tom"
def test_mount_at_root(test_client_factory: TestClientFactory) -> None:
response = client.get("/")
assert response.status_code == 200
- client = test_client_factory(
- mixed_hosts_app, base_url="https://port.example.org:3600/"
- )
+ client = test_client_factory(mixed_hosts_app, base_url="https://port.example.org:3600/")
response = client.get("/users")
assert response.status_code == 404
response = client.get("/")
assert response.status_code == 200
- client = test_client_factory(
- mixed_hosts_app, base_url="https://port.example.org:5600/"
- )
+ client = test_client_factory(mixed_hosts_app, base_url="https://port.example.org:5600/")
response = client.get("/")
assert response.status_code == 200
def test_host_reverse_urls() -> None:
+ assert mixed_hosts_app.url_path_for("homepage").make_absolute_url("https://whatever") == "https://www.example.org/"
assert (
- mixed_hosts_app.url_path_for("homepage").make_absolute_url("https://whatever")
- == "https://www.example.org/"
- )
- assert (
- mixed_hosts_app.url_path_for("users").make_absolute_url("https://whatever")
- == "https://www.example.org/users"
+ mixed_hosts_app.url_path_for("users").make_absolute_url("https://whatever") == "https://www.example.org/users"
)
assert (
mixed_hosts_app.url_path_for("api:users").make_absolute_url("https://whatever")
== "https://api.example.org/users"
)
assert (
- mixed_hosts_app.url_path_for("port:homepage").make_absolute_url(
- "https://whatever"
- )
+ mixed_hosts_app.url_path_for("port:homepage").make_absolute_url("https://whatever")
== "https://port.example.org:3600/"
)
await response(scope, receive, send)
-subdomain_router = Router(
- routes=[Host("{subdomain}.example.org", app=subdomain_app, name="subdomains")]
-)
+subdomain_router = Router(routes=[Host("{subdomain}.example.org", app=subdomain_app, name="subdomains")])
def test_subdomain_routing(test_client_factory: TestClientFactory) -> None:
def test_subdomain_reverse_urls() -> None:
assert (
- subdomain_router.url_path_for(
- "subdomains", subdomain="foo", path="/homepage"
- ).make_absolute_url("https://whatever")
+ subdomain_router.url_path_for("subdomains", subdomain="foo", path="/homepage").make_absolute_url(
+ "https://whatever"
+ )
== "https://foo.example.org/homepage"
)
def test_url_for_with_root_path(test_client_factory: TestClientFactory) -> None:
app = Starlette(routes=echo_url_routes)
- client = test_client_factory(
- app, base_url="https://www.example.org/", root_path="/sub_path"
- )
+ client = test_client_factory(app, base_url="https://www.example.org/", root_path="/sub_path")
response = client.get("/sub_path/")
assert response.json() == {
"index": "https://www.example.org/sub_path/",
nonlocal shutdown_complete
shutdown_complete = True
- with pytest.deprecated_call(
- match="The on_startup and on_shutdown parameters are deprecated"
- ):
+ with pytest.deprecated_call(match="The on_startup and on_shutdown parameters are deprecated"):
app = Router(
on_startup=[run_startup],
on_shutdown=[run_shutdown],
nonlocal shutdown_called
shutdown_called = True
- with pytest.deprecated_call(
- match="The on_startup and on_shutdown parameters are deprecated"
- ):
+ with pytest.deprecated_call(match="The on_startup and on_shutdown parameters are deprecated"):
with pytest.warns(
- UserWarning,
- match=(
- "The `lifespan` parameter cannot be used with `on_startup` or `on_shutdown`." # noqa: E501
- ),
+ UserWarning, match="The `lifespan` parameter cannot be used with `on_startup` or `on_shutdown`."
):
- app = Router(
- on_startup=[run_startup], on_shutdown=[run_shutdown], lifespan=lifespan
- )
+ app = Router(on_startup=[run_startup], on_shutdown=[run_shutdown], lifespan=lifespan)
assert not lifespan_called
assert not startup_called
nonlocal shutdown_complete
shutdown_complete = True
- with pytest.deprecated_call(
- match="The on_startup and on_shutdown parameters are deprecated"
- ):
+ with pytest.deprecated_call(match="The on_startup and on_shutdown parameters are deprecated"):
app = Router(
on_startup=[run_startup],
on_shutdown=[run_shutdown],
del scope["state"]
await app(scope, receive, send)
- with pytest.raises(
- RuntimeError, match='The server does not support "state" in the lifespan scope'
- ):
+ with pytest.raises(RuntimeError, match='The server does not support "state" in the lifespan scope'):
with test_client_factory(no_state_wrapper):
raise AssertionError("Should not be called") # pragma: no cover
def run_startup() -> None:
raise RuntimeError()
- with pytest.deprecated_call(
- match="The on_startup and on_shutdown parameters are deprecated"
- ):
+ with pytest.deprecated_call(match="The on_startup and on_shutdown parameters are deprecated"):
router = Router(on_startup=[run_startup])
startup_failed = False
def run_shutdown() -> None:
raise RuntimeError()
- with pytest.deprecated_call(
- match="The on_startup and on_shutdown parameters are deprecated"
- ):
+ with pytest.deprecated_call(match="The on_startup and on_shutdown parameters are deprecated"):
app = Router(on_shutdown=[run_shutdown])
with pytest.raises(RuntimeError):
pytest.param(lambda request: ..., "<lambda>", id="lambda"),
],
)
-def test_route_name(
- endpoint: typing.Callable[..., Response], expected_name: str
-) -> None:
+def test_route_name(endpoint: typing.Callable[..., Response], expected_name: str) -> None:
assert Route(path="/", endpoint=endpoint).name == expected_name
def test_route_repr() -> None:
route = Route("/welcome", endpoint=homepage)
- assert (
- repr(route)
- == "Route(path='/welcome', name='homepage', methods=['GET', 'HEAD'])"
- )
+ assert repr(route) == "Route(path='/welcome', name='homepage', methods=['GET', 'HEAD'])"
def test_route_repr_without_methods() -> None:
)
-async def pure_asgi_echo_paths(
- scope: Scope, receive: Receive, send: Send, name: str
-) -> None:
+async def pure_asgi_echo_paths(scope: Scope, receive: Receive, send: Send, name: str) -> None:
data = {"name": name, "path": scope["path"], "root_path": scope["root_path"]}
content = json.dumps(data).encode("utf-8")
await send(
def test_paths_with_root_path(test_client_factory: TestClientFactory) -> None:
app = Starlette(routes=echo_paths_routes)
- client = test_client_factory(
- app, base_url="https://www.example.org/", root_path="/root"
- )
+ client = test_client_factory(app, base_url="https://www.example.org/", root_path="/root")
response = client.get("/root/path")
assert response.status_code == 200
assert response.json() == {
from starlette.websockets import WebSocket
from tests.types import TestClientFactory
-schemas = SchemaGenerator(
- {"openapi": "3.0.0", "info": {"title": "Example API", "version": "1.0"}}
-)
+schemas = SchemaGenerator({"openapi": "3.0.0", "info": {"title": "Example API", "version": "1.0"}})
def ws(session: WebSocket) -> None:
"get": {
"responses": {
200: {
- "description": "A list of " "organisations.",
+ "description": "A list of organisations.",
"examples": [{"name": "Foo Corp."}, {"name": "Acme Ltd."}],
}
}
},
},
"/regular-docstring-and-schema": {
- "get": {
- "responses": {
- 200: {"description": "This is included in the schema."}
- }
- }
+ "get": {"responses": {200: {"description": "This is included in the schema."}}}
},
"/subapp/subapp-endpoint": {
- "get": {
- "responses": {
- 200: {"description": "This endpoint is part of a subapp."}
- }
- }
+ "get": {"responses": {200: {"description": "This endpoint is part of a subapp."}}}
},
"/subapp2/subapp-endpoint": {
- "get": {
- "responses": {
- 200: {"description": "This endpoint is part of a subapp."}
- }
- }
+ "get": {"responses": {200: {"description": "This endpoint is part of a subapp."}}}
},
"/users": {
"get": {
}
}
},
- "post": {
- "responses": {
- 200: {"description": "A user.", "examples": {"username": "tom"}}
- }
- },
+ "post": {"responses": {200: {"description": "A user.", "examples": {"username": "tom"}}}},
},
"/users/{id}": {
"get": {
assert response.text == "<file content>"
-def test_staticfiles_with_pathlib(
- tmp_path: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_staticfiles_with_pathlib(tmp_path: Path, test_client_factory: TestClientFactory) -> None:
path = tmp_path / "example.txt"
with open(path, "w") as file:
file.write("<file content>")
assert response.text == "<file content>"
-def test_staticfiles_head_with_middleware(
- tmpdir: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_staticfiles_head_with_middleware(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
"""
see https://github.com/encode/starlette/pull/935
"""
with open(path, "w") as file:
file.write("x" * 100)
- async def does_nothing_middleware(
- request: Request, call_next: RequestResponseEndpoint
- ) -> Response:
+ async def does_nothing_middleware(request: Request, call_next: RequestResponseEndpoint) -> Response:
response = await call_next(request)
return response
assert response.text == "Method Not Allowed"
-def test_staticfiles_with_directory_returns_404(
- tmpdir: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_staticfiles_with_directory_returns_404(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
path = os.path.join(tmpdir, "example.txt")
with open(path, "w") as file:
file.write("<file content>")
assert response.text == "Not Found"
-def test_staticfiles_with_missing_file_returns_404(
- tmpdir: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_staticfiles_with_missing_file_returns_404(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
path = os.path.join(tmpdir, "example.txt")
with open(path, "w") as file:
file.write("<file content>")
assert "does not exist" in str(exc_info.value)
-def test_staticfiles_configured_with_missing_directory(
- tmpdir: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_staticfiles_configured_with_missing_directory(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
path = os.path.join(tmpdir, "no_such_directory")
app = StaticFiles(directory=path, check_dir=False)
client = test_client_factory(app)
assert "is not a directory" in str(exc_info.value)
-def test_staticfiles_config_check_occurs_only_once(
- tmpdir: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_staticfiles_config_check_occurs_only_once(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
app = StaticFiles(directory=tmpdir)
client = test_client_factory(app)
assert not app.config_checked
assert exc_info.value.detail == "Not Found"
-def test_staticfiles_never_read_file_for_head_method(
- tmpdir: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_staticfiles_never_read_file_for_head_method(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
path = os.path.join(tmpdir, "example.txt")
with open(path, "w") as file:
file.write("<file content>")
assert response.headers["content-length"] == "14"
-def test_staticfiles_304_with_etag_match(
- tmpdir: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_staticfiles_304_with_etag_match(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
path = os.path.join(tmpdir, "example.txt")
with open(path, "w") as file:
file.write("<file content>")
second_resp = client.get("/example.txt", headers={"if-none-match": last_etag})
assert second_resp.status_code == 304
assert second_resp.content == b""
- second_resp = client.get(
- "/example.txt", headers={"if-none-match": f'W/{last_etag}, "123"'}
- )
+ second_resp = client.get("/example.txt", headers={"if-none-match": f'W/{last_etag}, "123"'})
assert second_resp.status_code == 304
assert second_resp.content == b""
tmpdir: Path, test_client_factory: TestClientFactory
) -> None:
path = os.path.join(tmpdir, "example.txt")
- file_last_modified_time = time.mktime(
- time.strptime("2013-10-10 23:40:00", "%Y-%m-%d %H:%M:%S")
- )
+ file_last_modified_time = time.mktime(time.strptime("2013-10-10 23:40:00", "%Y-%m-%d %H:%M:%S"))
with open(path, "w") as file:
file.write("<file content>")
os.utime(path, (file_last_modified_time, file_last_modified_time))
app = StaticFiles(directory=tmpdir)
client = test_client_factory(app)
# last modified less than last request, 304
- response = client.get(
- "/example.txt", headers={"If-Modified-Since": "Thu, 11 Oct 2013 15:30:19 GMT"}
- )
+ response = client.get("/example.txt", headers={"If-Modified-Since": "Thu, 11 Oct 2013 15:30:19 GMT"})
assert response.status_code == 304
assert response.content == b""
# last modified greater than last request, 200 with content
- response = client.get(
- "/example.txt", headers={"If-Modified-Since": "Thu, 20 Feb 2012 15:30:19 GMT"}
- )
+ response = client.get("/example.txt", headers={"If-Modified-Since": "Thu, 20 Feb 2012 15:30:19 GMT"})
assert response.status_code == 200
assert response.content == b"<file content>"
-def test_staticfiles_html_normal(
- tmpdir: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_staticfiles_html_normal(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
path = os.path.join(tmpdir, "404.html")
with open(path, "w") as file:
file.write("<h1>Custom not found page</h1>")
assert response.text == "<h1>Custom not found page</h1>"
-def test_staticfiles_html_without_index(
- tmpdir: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_staticfiles_html_without_index(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
path = os.path.join(tmpdir, "404.html")
with open(path, "w") as file:
file.write("<h1>Custom not found page</h1>")
assert response.text == "<h1>Custom not found page</h1>"
-def test_staticfiles_html_without_404(
- tmpdir: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_staticfiles_html_without_404(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
path = os.path.join(tmpdir, "dir")
os.mkdir(path)
path = os.path.join(path, "index.html")
assert exc_info.value.status_code == 404
-def test_staticfiles_html_only_files(
- tmpdir: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_staticfiles_html_only_files(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
path = os.path.join(tmpdir, "hello.html")
with open(path, "w") as file:
file.write("<h1>Hello</h1>")
with open(path_some, "w") as file:
file.write("<p>some file</p>")
- common_modified_time = time.mktime(
- time.strptime("2013-10-10 23:40:00", "%Y-%m-%d %H:%M:%S")
- )
+ common_modified_time = time.mktime(time.strptime("2013-10-10 23:40:00", "%Y-%m-%d %H:%M:%S"))
os.utime(path_404, (common_modified_time, common_modified_time))
os.utime(path_some, (common_modified_time, common_modified_time))
tmp_path.chmod(original_mode)
-def test_staticfiles_with_missing_dir_returns_404(
- tmpdir: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_staticfiles_with_missing_dir_returns_404(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
path = os.path.join(tmpdir, "example.txt")
with open(path, "w") as file:
file.write("<file content>")
assert response.text == "Not Found"
-def test_staticfiles_access_file_as_dir_returns_404(
- tmpdir: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_staticfiles_access_file_as_dir_returns_404(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
path = os.path.join(tmpdir, "example.txt")
with open(path, "w") as file:
file.write("<file content>")
assert response.text == "Not Found"
-def test_staticfiles_filename_too_long(
- tmpdir: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_staticfiles_filename_too_long(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
routes = [Mount("/", app=StaticFiles(directory=tmpdir), name="static")]
app = Starlette(routes=routes)
client = test_client_factory(app)
assert response.text == "Internal Server Error"
-def test_staticfiles_follows_symlinks(
- tmpdir: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_staticfiles_follows_symlinks(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
statics_path = os.path.join(tmpdir, "statics")
os.mkdir(statics_path)
assert response.text == "<h1>Hello</h1>"
-def test_staticfiles_follows_symlink_directories(
- tmpdir: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_staticfiles_follows_symlink_directories(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
statics_path = os.path.join(tmpdir, "statics")
statics_html_path = os.path.join(statics_path, "html")
os.mkdir(statics_path)
(
(
"WS_1004_NO_STATUS_RCVD",
- "'WS_1004_NO_STATUS_RCVD' is deprecated. "
- "Use 'WS_1005_NO_STATUS_RCVD' instead.",
+ "'WS_1004_NO_STATUS_RCVD' is deprecated. Use 'WS_1005_NO_STATUS_RCVD' instead.",
),
(
"WS_1005_ABNORMAL_CLOSURE",
- "'WS_1005_ABNORMAL_CLOSURE' is deprecated. "
- "Use 'WS_1006_ABNORMAL_CLOSURE' instead.",
+ "'WS_1005_ABNORMAL_CLOSURE' is deprecated. Use 'WS_1006_ABNORMAL_CLOSURE' instead.",
),
),
)
assert set(response.context.keys()) == {"request"} # type: ignore
-def test_calls_context_processors(
- tmp_path: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_calls_context_processors(tmp_path: Path, test_client_factory: TestClientFactory) -> None:
path = tmp_path / "index.html"
path.write_text("<html>Hello {{ username }}</html>")
assert set(response.context.keys()) == {"request", "username"} # type: ignore
-def test_template_with_middleware(
- tmpdir: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_template_with_middleware(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
path = os.path.join(tmpdir, "index.html")
with open(path, "w") as file:
file.write("<html>Hello, <a href='{{ url_for('homepage') }}'>world</a></html>")
return templates.TemplateResponse(request, "index.html")
class CustomMiddleware(BaseHTTPMiddleware):
- async def dispatch(
- self, request: Request, call_next: RequestResponseEndpoint
- ) -> Response:
+ async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
return await call_next(request)
app = Starlette(
assert set(response.context.keys()) == {"request"} # type: ignore
-def test_templates_with_directories(
- tmp_path: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_templates_with_directories(tmp_path: Path, test_client_factory: TestClientFactory) -> None:
dir_a = tmp_path.resolve() / "a"
dir_a.mkdir()
template_a = dir_a / "template_a.html"
def test_templates_require_directory_or_environment() -> None:
- with pytest.raises(
- AssertionError, match="either 'directory' or 'env' arguments must be passed"
- ):
+ with pytest.raises(AssertionError, match="either 'directory' or 'env' arguments must be passed"):
Jinja2Templates() # type: ignore[call-overload]
def test_templates_require_directory_or_enviroment_not_both() -> None:
- with pytest.raises(
- AssertionError, match="either 'directory' or 'env' arguments must be passed"
- ):
+ with pytest.raises(AssertionError, match="either 'directory' or 'env' arguments must be passed"):
Jinja2Templates(directory="dir", env=jinja2.Environment())
assert template.render({}) == "Hello"
-def test_templates_with_environment(
- tmpdir: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_templates_with_environment(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
path = os.path.join(tmpdir, "index.html")
with open(path, "w") as file:
file.write("<html>Hello, <a href='{{ url_for('homepage') }}'>world</a></html>")
Jinja2Templates(str(tmpdir), autoescape=True)
-def test_templates_with_kwargs_only(
- tmpdir: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_templates_with_kwargs_only(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
# MAINTAINERS: remove after 1.0
path = os.path.join(tmpdir, "index.html")
with open(path, "w") as file:
templates = Jinja2Templates(directory=str(tmpdir))
def page(request: Request) -> Response:
- return templates.TemplateResponse(
- name="index.html", context={"request": request}
- )
+ return templates.TemplateResponse(name="index.html", context={"request": request})
app = Starlette(routes=[Route("/", page)])
client = test_client_factory(app)
spy.assert_called()
-def test_templates_when_first_argument_is_request(
- tmpdir: Path, test_client_factory: TestClientFactory
-) -> None:
+def test_templates_when_first_argument_is_request(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
# MAINTAINERS: remove after 1.0
path = os.path.join(tmpdir, "index.html")
with open(path, "w") as file:
assert client.headers.get("Authentication") == "Bearer 123"
-def test_use_testclient_as_contextmanager(
- test_client_factory: TestClientFactory, anyio_backend_name: str
-) -> None:
+def test_use_testclient_as_contextmanager(test_client_factory: TestClientFactory, anyio_backend_name: str) -> None:
"""
This test asserts a number of properties that are important for an
app level task_group
def test_error_on_startup(test_client_factory: TestClientFactory) -> None:
- with pytest.deprecated_call(
- match="The on_startup and on_shutdown parameters are deprecated"
- ):
+ with pytest.deprecated_call(match="The on_startup and on_shutdown parameters are deprecated"):
startup_error_app = Starlette(on_startup=[startup])
with pytest.raises(RuntimeError):
marks=[
pytest.mark.xfail(
sys.version_info < (3, 11),
- reason="Fails due to domain handling in http.cookiejar module (see "
- "#2152)",
+ reason="Fails due to domain handling in http.cookiejar module (see #2152)",
),
],
),
("example.com", False),
],
)
-def test_domain_restricted_cookies(
- test_client_factory: TestClientFactory, domain: str, ok: bool
-) -> None:
+def test_domain_restricted_cookies(test_client_factory: TestClientFactory, domain: str, ok: bool) -> None:
"""
Test that test client discards domain restricted cookies which do not match the
base_url of the testclient (`http://testserver` by default).
async def send(message: Message) -> None:
if message["type"] == "websocket.accept":
return
- # Simulate the exception the server would send to the application when the
- # client disconnects.
+ # Simulate the exception the server would send to the application when the client disconnects.
raise OSError
with pytest.raises(WebSocketDisconnect) as ctx:
"headers": [(b"content-type", b"text/plain"), (b"foo", b"bar")],
}
)
- await websocket.send(
- {
- "type": "websocket.http.response.body",
- "body": b"hard",
- "more_body": True,
- }
- )
- await websocket.send(
- {
- "type": "websocket.http.response.body",
- "body": b"body",
- }
- )
+ await websocket.send({"type": "websocket.http.response.body", "body": b"hard", "more_body": True})
+ await websocket.send({"type": "websocket.http.response.body", "body": b"body"})
client = test_client_factory(app)
with pytest.raises(WebSocketDenialResponse) as exc:
client = test_client_factory(app)
with pytest.raises(
RuntimeError,
- match=(
- 'Expected ASGI message "websocket.http.response.body", but got '
- "'websocket.http.response.start'"
- ),
+ match=("Expected ASGI message \"websocket.http.response.body\", but got 'websocket.http.response.start'"),
):
with client.websocket_connect("/"):
pass # pragma: no cover
async def mock_send(message: Message) -> None: ... # pragma: no cover
- websocket = WebSocket(
- {"type": "websocket", "path": "/abc/", "headers": []},
- receive=mock_receive,
- send=mock_send,
- )
+ websocket = WebSocket({"type": "websocket", "path": "/abc/", "headers": []}, receive=mock_receive, send=mock_send)
assert websocket["type"] == "websocket"
assert dict(websocket) == {"type": "websocket", "path": "/abc/", "headers": []}
assert len(websocket) == 3