import sys
-from typing import List, Optional
+from typing import Any, Callable, Dict, Iterator, List, Optional
import anyio
import pytest
from starlette.datastructures import Address, State
from starlette.requests import ClientDisconnect, Request
from starlette.responses import JSONResponse, PlainTextResponse, Response
-from starlette.types import Message, Scope
+from starlette.testclient import TestClient
+from starlette.types import Message, Receive, Scope, Send
+TestClientFactory = Callable[..., TestClient]
-def test_request_url(test_client_factory):
- async def app(scope, receive, send):
+
+def test_request_url(test_client_factory: TestClientFactory) -> None:
+ async def app(scope: Scope, receive: Receive, send: Send) -> None:
request = Request(scope, receive)
data = {"method": request.method, "url": str(request.url)}
response = JSONResponse(data)
assert response.json() == {"method": "GET", "url": "https://example.org:123/"}
-def test_request_query_params(test_client_factory):
- async def app(scope, receive, send):
+def test_request_query_params(test_client_factory: TestClientFactory) -> None:
+ async def app(scope: Scope, receive: Receive, send: Send) -> None:
request = Request(scope, receive)
params = dict(request.query_params)
response = JSONResponse({"params": params})
any(module in sys.modules for module in ("brotli", "brotlicffi")),
reason='urllib3 includes "br" to the "accept-encoding" headers.',
)
-def test_request_headers(test_client_factory):
- async def app(scope, receive, send):
+def test_request_headers(test_client_factory: TestClientFactory) -> None:
+ async def app(scope: Scope, receive: Receive, send: Send) -> None:
request = Request(scope, receive)
headers = dict(request.headers)
response = JSONResponse({"headers": headers})
({}, None),
],
)
-def test_request_client(scope: Scope, expected_client: Optional[Address]):
+def test_request_client(scope: Scope, expected_client: Optional[Address]) -> None:
scope.update({"type": "http"}) # required by Request's constructor
client = Request(scope).client
assert client == expected_client
-def test_request_body(test_client_factory):
- async def app(scope, receive, send):
+def test_request_body(test_client_factory: TestClientFactory) -> None:
+ async def app(scope: Scope, receive: Receive, send: Send) -> None:
request = Request(scope, receive)
body = await request.body()
response = JSONResponse({"body": body.decode()})
response = client.post("/", json={"a": "123"})
assert response.json() == {"body": '{"a": "123"}'}
- response = client.post("/", data="abc")
+ response = client.post("/", data="abc") # type: ignore
assert response.json() == {"body": "abc"}
-def test_request_stream(test_client_factory):
- async def app(scope, receive, send):
+def test_request_stream(test_client_factory: TestClientFactory) -> None:
+ async def app(scope: Scope, receive: Receive, send: Send) -> None:
request = Request(scope, receive)
body = b""
async for chunk in request.stream():
response = client.post("/", json={"a": "123"})
assert response.json() == {"body": '{"a": "123"}'}
- response = client.post("/", data="abc")
+ response = client.post("/", data="abc") # type: ignore
assert response.json() == {"body": "abc"}
-def test_request_form_urlencoded(test_client_factory):
- async def app(scope, receive, send):
+def test_request_form_urlencoded(test_client_factory: TestClientFactory) -> None:
+ async def app(scope: Scope, receive: Receive, send: Send) -> None:
request = Request(scope, receive)
form = await request.form()
response = JSONResponse({"form": dict(form)})
assert response.json() == {"form": {"abc": "123 @"}}
-def test_request_form_context_manager(test_client_factory):
- async def app(scope, receive, send):
+def test_request_form_context_manager(test_client_factory: TestClientFactory) -> None:
+ async def app(scope: Scope, receive: Receive, send: Send) -> None:
request = Request(scope, receive)
async with request.form() as form:
response = JSONResponse({"form": dict(form)})
assert response.json() == {"form": {"abc": "123 @"}}
-def test_request_body_then_stream(test_client_factory):
- async def app(scope, receive, send):
+def test_request_body_then_stream(test_client_factory: TestClientFactory) -> None:
+ async def app(scope: Scope, receive: Receive, send: Send) -> None:
request = Request(scope, receive)
body = await request.body()
chunks = b""
client = test_client_factory(app)
- response = client.post("/", data="abc")
+ response = client.post("/", data="abc") # type: ignore
assert response.json() == {"body": "abc", "stream": "abc"}
-def test_request_stream_then_body(test_client_factory):
- async def app(scope, receive, send):
+def test_request_stream_then_body(test_client_factory: TestClientFactory) -> None:
+ async def app(scope: Scope, receive: Receive, send: Send) -> None:
request = Request(scope, receive)
chunks = b""
async for chunk in request.stream():
client = test_client_factory(app)
- response = client.post("/", data="abc")
+ response = client.post("/", data="abc") # type: ignore
assert response.json() == {"body": "<stream consumed>", "stream": "abc"}
-def test_request_json(test_client_factory):
- async def app(scope, receive, send):
+def test_request_json(test_client_factory: TestClientFactory) -> None:
+ async def app(scope: Scope, receive: Receive, send: Send) -> None:
request = Request(scope, receive)
data = await request.json()
response = JSONResponse({"json": data})
assert response.json() == {"json": {"a": "123"}}
-def test_request_scope_interface():
+def test_request_scope_interface() -> None:
"""
A Request can be instantiated with a scope, and presents a `Mapping`
interface.
assert len(request) == 3
-def test_request_raw_path(test_client_factory):
- async def app(scope, receive, send):
+def test_request_raw_path(test_client_factory: TestClientFactory) -> None:
+ async def app(scope: Scope, receive: Receive, send: Send) -> None:
request = Request(scope, receive)
path = request.scope["path"]
raw_path = request.scope["raw_path"]
assert response.text == "/he/llo, b'/he%2Fllo'"
-def test_request_without_setting_receive(test_client_factory):
+def test_request_without_setting_receive(
+ test_client_factory: TestClientFactory,
+) -> None:
"""
If Request is instantiated without the receive channel, then .body()
is not available.
"""
- async def app(scope, receive, send):
+ async def app(scope: Scope, receive: Receive, send: Send) -> None:
request = Request(scope)
try:
data = await request.json()
assert response.json() == {"json": "Receive channel not available"}
-def test_request_disconnect(anyio_backend_name, anyio_backend_options):
+def test_request_disconnect(
+ anyio_backend_name: str,
+ anyio_backend_options: Dict[str, Any],
+) -> None:
"""
If a client disconnect occurs while reading request body
then ClientDisconnect should be raised.
"""
- async def app(scope, receive, send):
+ async def app(scope: Scope, receive: Receive, send: Send) -> None:
request = Request(scope, receive)
await request.body()
- async def receiver():
+ async def receiver() -> Message:
return {"type": "http.disconnect"}
scope = {"type": "http", "method": "POST", "path": "/"}
with pytest.raises(ClientDisconnect):
anyio.run(
- app,
+ app, # type: ignore
scope,
receiver,
None,
)
-def test_request_is_disconnected(test_client_factory):
+def test_request_is_disconnected(test_client_factory: TestClientFactory) -> None:
"""
If a client disconnect occurs while reading request body
then ClientDisconnect should be raised.
"""
disconnected_after_response = None
- async def app(scope, receive, send):
+ async def app(scope: Scope, receive: Receive, send: Send) -> None:
nonlocal disconnected_after_response
request = Request(scope, receive)
assert disconnected_after_response
-def test_request_state_object():
+def test_request_state_object() -> None:
scope = {"state": {"old": "foo"}}
s = State(scope["state"])
s.new
-def test_request_state(test_client_factory):
- async def app(scope, receive, send):
+def test_request_state(test_client_factory: TestClientFactory) -> None:
+ async def app(scope: Scope, receive: Receive, send: Send) -> None:
request = Request(scope, receive)
request.state.example = 123
response = JSONResponse({"state.example": request.state.example})
assert response.json() == {"state.example": 123}
-def test_request_cookies(test_client_factory):
- async def app(scope, receive, send):
+def test_request_cookies(test_client_factory: TestClientFactory) -> None:
+ async def app(scope: Scope, receive: Receive, send: Send) -> None:
request = Request(scope, receive)
mycookie = request.cookies.get("mycookie")
if mycookie:
assert response.text == "Hello, cookies!"
-def test_cookie_lenient_parsing(test_client_factory):
+def test_cookie_lenient_parsing(test_client_factory: TestClientFactory) -> None:
"""
The following test is based on a cookie set by Okta, a well-known authorization
service. It turns out that it's common practice to set cookies that would be
"sessionCookie",
}
- async def app(scope, receive, send):
+ async def app(scope: Scope, receive: Receive, send: Send) -> None:
request = Request(scope, receive)
response = JSONResponse({"cookies": request.cookies})
await response(scope, receive, send)
("a=b; h=i; a=c", {"a": "c", "h": "i"}),
],
)
-def test_cookies_edge_cases(set_cookie, expected, test_client_factory):
- async def app(scope, receive, send):
+def test_cookies_edge_cases(
+ set_cookie: str,
+ expected: Dict[str, str],
+ test_client_factory: TestClientFactory,
+) -> None:
+ async def app(scope: Scope, receive: Receive, send: Send) -> None:
request = Request(scope, receive)
response = JSONResponse({"cookies": request.cookies})
await response(scope, receive, send)
# (" = b ; ; = ; c = ; ", {"": "b", "c": ""}),
],
)
-def test_cookies_invalid(set_cookie, expected, test_client_factory):
+def test_cookies_invalid(
+ set_cookie: str,
+ expected: Dict[str, str],
+ test_client_factory: TestClientFactory,
+) -> None:
"""
Cookie strings that are against the RFC6265 spec but which browsers will send if set
via document.cookie.
"""
- async def app(scope, receive, send):
+ async def app(scope: Scope, receive: Receive, send: Send) -> None:
request = Request(scope, receive)
response = JSONResponse({"cookies": request.cookies})
await response(scope, receive, send)
assert result["cookies"] == expected
-def test_chunked_encoding(test_client_factory):
- async def app(scope, receive, send):
+def test_chunked_encoding(test_client_factory: TestClientFactory) -> None:
+ async def app(scope: Scope, receive: Receive, send: Send) -> None:
request = Request(scope, receive)
body = await request.body()
response = JSONResponse({"body": body.decode()})
client = test_client_factory(app)
- def post_body():
+ def post_body() -> Iterator[bytes]:
yield b"foo"
yield b"bar"
- response = client.post("/", data=post_body())
+ response = client.post("/", data=post_body()) # type: ignore
assert response.json() == {"body": "foobar"}
-def test_request_send_push_promise(test_client_factory):
- async def app(scope, receive, send):
+def test_request_send_push_promise(test_client_factory: TestClientFactory) -> None:
+ async def app(scope: Scope, receive: Receive, send: Send) -> None:
# the server is push-enabled
scope["extensions"]["http.response.push"] = {}
assert response.json() == {"json": "OK"}
-def test_request_send_push_promise_without_push_extension(test_client_factory):
+def test_request_send_push_promise_without_push_extension(
+ test_client_factory: TestClientFactory,
+) -> None:
"""
If server does not support the `http.response.push` extension,
.send_push_promise() does nothing.
"""
- async def app(scope, receive, send):
+ async def app(scope: Scope, receive: Receive, send: Send) -> None:
request = Request(scope)
await request.send_push_promise("/style.css")
assert response.json() == {"json": "OK"}
-def test_request_send_push_promise_without_setting_send(test_client_factory):
+def test_request_send_push_promise_without_setting_send(
+ test_client_factory: TestClientFactory,
+) -> None:
"""
If Request is instantiated without the send channel, then
.send_push_promise() is not available.
"""
- async def app(scope, receive, send):
+ async def app(scope: Scope, receive: Receive, send: Send) -> None:
# the server is push-enabled
scope["extensions"]["http.response.push"] = {}