import time
from http.cookies import SimpleCookie
from pathlib import Path
+from typing import AsyncIterator, Callable, Iterator, Union
import anyio
import pytest
StreamingResponse,
)
from starlette.testclient import TestClient
-from starlette.types import Message
+from starlette.types import Message, Receive, Scope, Send
+TestClientFactory = Callable[..., TestClient]
-def test_text_response(test_client_factory):
- async def app(scope, receive, send):
+
+def test_text_response(test_client_factory: TestClientFactory) -> None:
+ async def app(scope: Scope, receive: Receive, send: Send) -> None:
response = Response("hello, world", media_type="text/plain")
await response(scope, receive, send)
assert response.text == "hello, world"
-def test_bytes_response(test_client_factory):
- async def app(scope, receive, send):
+def test_bytes_response(test_client_factory: TestClientFactory) -> None:
+ async def app(scope: Scope, receive: Receive, send: Send) -> None:
response = Response(b"xxxxx", media_type="image/png")
await response(scope, receive, send)
assert response.content == b"xxxxx"
-def test_json_none_response(test_client_factory):
- async def app(scope, receive, send):
+def test_json_none_response(test_client_factory: TestClientFactory) -> None:
+ async def app(scope: Scope, receive: Receive, send: Send) -> None:
response = JSONResponse(None)
await response(scope, receive, send)
assert response.content == b"null"
-def test_redirect_response(test_client_factory):
- async def app(scope, receive, send):
+def test_redirect_response(test_client_factory: TestClientFactory) -> None:
+ async def app(scope: Scope, receive: Receive, send: Send) -> None:
if scope["path"] == "/":
response = Response("hello, world", media_type="text/plain")
else:
assert response.url == "http://testserver/"
-def test_quoting_redirect_response(test_client_factory):
- async def app(scope, receive, send):
+def test_quoting_redirect_response(test_client_factory: TestClientFactory) -> None:
+ async def app(scope: Scope, receive: Receive, send: Send) -> None:
if scope["path"] == "/I ♥ Starlette/":
response = Response("hello, world", media_type="text/plain")
else:
assert response.url == "http://testserver/I%20%E2%99%A5%20Starlette/"
-def test_redirect_response_content_length_header(test_client_factory):
- async def app(scope, receive, send):
+def test_redirect_response_content_length_header(
+ test_client_factory: TestClientFactory,
+) -> None:
+ async def app(scope: Scope, receive: Receive, send: Send) -> None:
if scope["path"] == "/":
response = Response("hello", media_type="text/plain") # pragma: nocover
else:
assert response.headers["content-length"] == "0"
-def test_streaming_response(test_client_factory):
+def test_streaming_response(test_client_factory: TestClientFactory) -> None:
filled_by_bg_task = ""
- async def app(scope, receive, send):
- async def numbers(minimum, maximum):
+ async def app(scope: Scope, receive: Receive, send: Send) -> None:
+ async def numbers(minimum: int, maximum: int) -> AsyncIterator[str]:
for i in range(minimum, maximum + 1):
yield str(i)
if i != maximum:
yield ", "
await anyio.sleep(0)
- async def numbers_for_cleanup(start=1, stop=5):
+ async def numbers_for_cleanup(start: int = 1, stop: int = 5) -> None:
nonlocal filled_by_bg_task
async for thing in numbers(start, stop):
filled_by_bg_task = filled_by_bg_task + thing
assert filled_by_bg_task == "6, 7, 8, 9"
-def test_streaming_response_custom_iterator(test_client_factory):
- async def app(scope, receive, send):
+def test_streaming_response_custom_iterator(
+ test_client_factory: TestClientFactory,
+) -> None:
+ async def app(scope: Scope, receive: Receive, send: Send) -> None:
class CustomAsyncIterator:
- def __init__(self):
+ def __init__(self) -> None:
self._called = 0
- def __aiter__(self):
+ def __aiter__(self) -> AsyncIterator[str]:
return self
- async def __anext__(self):
+ async def __anext__(self) -> str:
if self._called == 5:
raise StopAsyncIteration()
self._called += 1
assert response.text == "12345"
-def test_streaming_response_custom_iterable(test_client_factory):
- async def app(scope, receive, send):
+def test_streaming_response_custom_iterable(
+ test_client_factory: TestClientFactory,
+) -> None:
+ async def app(scope: Scope, receive: Receive, send: Send) -> None:
class CustomAsyncIterable:
- async def __aiter__(self):
+ async def __aiter__(self) -> AsyncIterator[Union[str, bytes]]:
for i in range(5):
yield str(i + 1)
assert response.text == "12345"
-def test_sync_streaming_response(test_client_factory):
- async def app(scope, receive, send):
- def numbers(minimum, maximum):
+def test_sync_streaming_response(test_client_factory: TestClientFactory) -> None:
+ async def app(scope: Scope, receive: Receive, send: Send) -> None:
+ def numbers(minimum: int, maximum: int) -> Iterator[str]:
for i in range(minimum, maximum + 1):
yield str(i)
if i != maximum:
assert response.text == "1, 2, 3, 4, 5"
-def test_response_headers(test_client_factory):
- async def app(scope, receive, send):
+def test_response_headers(test_client_factory: TestClientFactory) -> None:
+ async def app(scope: Scope, receive: Receive, send: Send) -> None:
headers = {"x-header-1": "123", "x-header-2": "456"}
response = Response("hello, world", media_type="text/plain", headers=headers)
response.headers["x-header-2"] = "789"
assert response.headers["x-header-2"] == "789"
-def test_response_phrase(test_client_factory):
+def test_response_phrase(test_client_factory: TestClientFactory) -> None:
app = Response(status_code=204)
client = test_client_factory(app)
response = client.get("/")
assert response.reason_phrase == ""
-def test_file_response(tmpdir, test_client_factory):
+def test_file_response(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
path = os.path.join(tmpdir, "xyz")
content = b"<file content>" * 1000
with open(path, "wb") as file:
filled_by_bg_task = ""
- async def numbers(minimum, maximum):
+ async def numbers(minimum: int, maximum: int) -> AsyncIterator[str]:
for i in range(minimum, maximum + 1):
yield str(i)
if i != maximum:
yield ", "
await anyio.sleep(0)
- async def numbers_for_cleanup(start=1, stop=5):
+ async def numbers_for_cleanup(start: int = 1, stop: int = 5) -> None:
nonlocal filled_by_bg_task
async for thing in numbers(start, stop):
filled_by_bg_task = filled_by_bg_task + thing
cleanup_task = BackgroundTask(numbers_for_cleanup, start=6, stop=9)
- async def app(scope, receive, send):
+ async def app(scope: Scope, receive: Receive, send: Send) -> None:
response = FileResponse(
path=path, filename="example.png", background=cleanup_task
)
@pytest.mark.anyio
-async def test_file_response_on_head_method(tmpdir: Path):
+async def test_file_response_on_head_method(tmpdir: Path) -> None:
path = os.path.join(tmpdir, "xyz")
content = b"<file content>" * 1000
with open(path, "wb") as file:
await app({"type": "http", "method": "head"}, receive, send)
-def test_file_response_with_directory_raises_error(tmpdir, test_client_factory):
+def test_file_response_with_directory_raises_error(
+ tmpdir: Path, test_client_factory: TestClientFactory
+) -> None:
app = FileResponse(path=tmpdir, filename="example.png")
client = test_client_factory(app)
with pytest.raises(RuntimeError) as exc_info:
assert "is not a file" in str(exc_info.value)
-def test_file_response_with_missing_file_raises_error(tmpdir, test_client_factory):
+def test_file_response_with_missing_file_raises_error(
+ tmpdir: Path, test_client_factory: TestClientFactory
+) -> None:
path = os.path.join(tmpdir, "404.txt")
app = FileResponse(path=path, filename="404.txt")
client = test_client_factory(app)
assert "does not exist" in str(exc_info.value)
-def test_file_response_with_chinese_filename(tmpdir, test_client_factory):
+def test_file_response_with_chinese_filename(
+ tmpdir: Path, test_client_factory: TestClientFactory
+) -> None:
content = b"file content"
filename = "你好.txt" # probably "Hello.txt" in Chinese
path = os.path.join(tmpdir, filename)
assert response.headers["content-disposition"] == expected_disposition
-def test_file_response_with_inline_disposition(tmpdir, test_client_factory):
+def test_file_response_with_inline_disposition(
+ tmpdir: Path, test_client_factory: TestClientFactory
+) -> None:
content = b"file content"
filename = "hello.txt"
path = os.path.join(tmpdir, filename)
assert response.headers["content-disposition"] == expected_disposition
-def test_file_response_with_method_warns(tmpdir, test_client_factory):
+def test_file_response_with_method_warns(
+ tmpdir: Path, test_client_factory: TestClientFactory
+) -> None:
with pytest.warns(DeprecationWarning):
FileResponse(path=tmpdir, filename="example.png", method="GET")
@pytest.mark.anyio
-async def test_file_response_with_pathsend(tmpdir: Path):
+async def test_file_response_with_pathsend(tmpdir: Path) -> None:
path = os.path.join(tmpdir, "xyz")
content = b"<file content>" * 1000
with open(path, "wb") as file:
)
-def test_set_cookie(test_client_factory, monkeypatch):
+def test_set_cookie(
+ test_client_factory: TestClientFactory, monkeypatch: pytest.MonkeyPatch
+) -> None:
# Mock time used as a reference for `Expires` by stdlib `SimpleCookie`.
mocked_now = dt.datetime(2037, 1, 22, 12, 0, 0, tzinfo=dt.timezone.utc)
monkeypatch.setattr(time, "time", lambda: mocked_now.timestamp())
- async def app(scope, receive, send):
+ async def app(scope: Scope, receive: Receive, send: Send) -> None:
response = Response("Hello, world!", media_type="text/plain")
response.set_cookie(
"mycookie",
pytest.param(10, id="int"),
],
)
-def test_expires_on_set_cookie(test_client_factory, monkeypatch, expires):
+def test_expires_on_set_cookie(
+ test_client_factory: TestClientFactory,
+ monkeypatch: pytest.MonkeyPatch,
+ expires: str,
+) -> None:
# Mock time used as a reference for `Expires` by stdlib `SimpleCookie`.
mocked_now = dt.datetime(2037, 1, 22, 12, 0, 0, tzinfo=dt.timezone.utc)
monkeypatch.setattr(time, "time", lambda: mocked_now.timestamp())
- async def app(scope, receive, send):
+ async def app(scope: Scope, receive: Receive, send: Send) -> None:
response = Response("Hello, world!", media_type="text/plain")
response.set_cookie("mycookie", "myvalue", expires=expires)
await response(scope, receive, send)
assert cookie["mycookie"]["expires"] == "Thu, 22 Jan 2037 12:00:10 GMT"
-def test_delete_cookie(test_client_factory):
- async def app(scope, receive, send):
+def test_delete_cookie(test_client_factory: TestClientFactory) -> None:
+ async def app(scope: Scope, receive: Receive, send: Send) -> None:
request = Request(scope, receive)
response = Response("Hello, world!", media_type="text/plain")
if request.cookies.get("mycookie"):
assert not response.cookies.get("mycookie")
-def test_populate_headers(test_client_factory):
+def test_populate_headers(test_client_factory: TestClientFactory) -> None:
app = Response(content="hi", headers={}, media_type="text/html")
client = test_client_factory(app)
response = client.get("/")
assert response.headers["content-type"] == "text/html; charset=utf-8"
-def test_head_method(test_client_factory):
+def test_head_method(test_client_factory: TestClientFactory) -> None:
app = Response("hello, world", media_type="text/plain")
client = test_client_factory(app)
response = client.head("/")
assert response.text == ""
-def test_empty_response(test_client_factory):
+def test_empty_response(test_client_factory: TestClientFactory) -> None:
app = Response()
client: TestClient = test_client_factory(app)
response = client.get("/")
assert "content-type" not in response.headers
-def test_empty_204_response(test_client_factory):
+def test_empty_204_response(test_client_factory: TestClientFactory) -> None:
app = Response(status_code=204)
client: TestClient = test_client_factory(app)
response = client.get("/")
assert "content-length" not in response.headers
-def test_non_empty_response(test_client_factory):
+def test_non_empty_response(test_client_factory: TestClientFactory) -> None:
app = Response(content="hi")
client: TestClient = test_client_factory(app)
response = client.get("/")
assert response.headers["content-length"] == "2"
-def test_response_do_not_add_redundant_charset(test_client_factory):
+def test_response_do_not_add_redundant_charset(
+ test_client_factory: TestClientFactory,
+) -> None:
app = Response(media_type="text/plain; charset=utf-8")
client = test_client_factory(app)
response = client.get("/")
assert response.headers["content-type"] == "text/plain; charset=utf-8"
-def test_file_response_known_size(tmpdir, test_client_factory):
+def test_file_response_known_size(
+ tmpdir: Path, test_client_factory: TestClientFactory
+) -> None:
path = os.path.join(tmpdir, "xyz")
content = b"<file content>" * 1000
with open(path, "wb") as file:
assert response.headers["content-length"] == str(len(content))
-def test_streaming_response_unknown_size(test_client_factory):
+def test_streaming_response_unknown_size(
+ test_client_factory: TestClientFactory,
+) -> None:
app = StreamingResponse(content=iter(["hello", "world"]))
client: TestClient = test_client_factory(app)
response = client.get("/")
assert "content-length" not in response.headers
-def test_streaming_response_known_size(test_client_factory):
+def test_streaming_response_known_size(test_client_factory: TestClientFactory) -> None:
app = StreamingResponse(
content=iter(["hello", "world"]), headers={"content-length": "10"}
)
@pytest.mark.anyio
-async def test_streaming_response_stops_if_receiving_http_disconnect():
+async def test_streaming_response_stops_if_receiving_http_disconnect() -> None:
streamed = 0
disconnected = anyio.Event()
- async def receive_disconnect():
+ async def receive_disconnect() -> Message:
await disconnected.wait()
return {"type": "http.disconnect"}
- async def send(message):
+ async def send(message: Message) -> None:
nonlocal streamed
if message["type"] == "http.response.body":
streamed += len(message.get("body", b""))
if streamed >= 16:
disconnected.set()
- async def stream_indefinitely():
+ async def stream_indefinitely() -> AsyncIterator[bytes]:
while True:
# Need a sleep for the event loop to switch to another task
await anyio.sleep(0)