)
break
except asyncio.TimeoutError:
- # We check our flag at the possible moment, in order to
+ # We check our flag at the first possible moment, in order to
# allow us to suppress write timeouts, if we've since
# switched over to read-timeout mode.
should_raise = flag is None or flag.raise_on_write_timeout
-import asyncio
-import functools
-
import pytest
import httpx
-def threadpool(func):
- """
- Our sync tests should run in separate thread to the uvicorn server.
- """
-
- @functools.wraps(func)
- async def wrapped(*args, **kwargs):
- nonlocal func
-
- loop = asyncio.get_event_loop()
- if kwargs:
- func = functools.partial(func, **kwargs)
- await loop.run_in_executor(None, func, *args)
-
- return pytest.mark.asyncio(wrapped)
-
-
-@threadpool
def test_get(server):
url = "http://127.0.0.1:8000/"
with httpx.Client() as http:
assert repr(response) == "<Response [200 OK]>"
-@threadpool
def test_post(server):
with httpx.Client() as http:
response = http.post("http://127.0.0.1:8000/", data=b"Hello, world!")
assert response.reason_phrase == "OK"
-@threadpool
def test_post_json(server):
with httpx.Client() as http:
response = http.post("http://127.0.0.1:8000/", json={"text": "Hello, world!"})
assert response.reason_phrase == "OK"
-@threadpool
def test_stream_response(server):
with httpx.Client() as http:
response = http.get("http://127.0.0.1:8000/", stream=True)
assert content == b"Hello, world!"
-@threadpool
def test_stream_iterator(server):
with httpx.Client() as http:
response = http.get("http://127.0.0.1:8000/", stream=True)
assert body == b"Hello, world!"
-@threadpool
def test_raw_iterator(server):
with httpx.Client() as http:
response = http.get("http://127.0.0.1:8000/", stream=True)
response.close() # TODO: should Response be available as context managers?
-@threadpool
def test_raise_for_status(server):
with httpx.Client() as client:
for status_code in (200, 400, 404, 500, 505):
assert response.raise_for_status() is None
-@threadpool
def test_options(server):
with httpx.Client() as http:
response = http.options("http://127.0.0.1:8000/")
assert response.reason_phrase == "OK"
-@threadpool
def test_head(server):
with httpx.Client() as http:
response = http.head("http://127.0.0.1:8000/")
assert response.reason_phrase == "OK"
-@threadpool
def test_put(server):
with httpx.Client() as http:
response = http.put("http://127.0.0.1:8000/", data=b"Hello, world!")
assert response.reason_phrase == "OK"
-@threadpool
def test_patch(server):
with httpx.Client() as http:
response = http.patch("http://127.0.0.1:8000/", data=b"Hello, world!")
assert response.reason_phrase == "OK"
-@threadpool
def test_delete(server):
with httpx.Client() as http:
response = http.delete("http://127.0.0.1:8000/")
assert response.reason_phrase == "OK"
-@threadpool
def test_base_url(server):
base_url = "http://127.0.0.1:8000/"
with httpx.Client(base_url=base_url) as http:
import asyncio
+import threading
+import time
import pytest
import trustme
)
-@pytest.fixture
+SERVER_SCOPE = "session"
+
+
+@pytest.fixture(scope=SERVER_SCOPE)
def example_cert():
ca = CAWithPKEncryption()
ca.issue_cert("example.org")
return ca
-@pytest.fixture
+@pytest.fixture(scope=SERVER_SCOPE)
def cert_pem_file(example_cert):
with example_cert.cert_pem.tempfile() as tmp:
yield tmp
-@pytest.fixture
+@pytest.fixture(scope=SERVER_SCOPE)
def cert_private_key_file(example_cert):
with example_cert.private_key_pem.tempfile() as tmp:
yield tmp
-@pytest.fixture
+@pytest.fixture(scope=SERVER_SCOPE)
def cert_encrypted_private_key_file(example_cert):
with example_cert.encrypted_private_key_pem.tempfile() as tmp:
yield tmp
+class TestServer(Server):
+ def install_signal_handlers(self) -> None:
+ # Disable the default installation of handlers for signals such as SIGTERM,
+ # because it can only be done in the main thread.
+ pass
+
+ async def serve(self, sockets=None):
+ self.restart_requested = asyncio.Event()
+
+ loop = asyncio.get_event_loop()
+ tasks = {
+ loop.create_task(super().serve(sockets=sockets)),
+ loop.create_task(self.watch_restarts()),
+ }
+
+ await asyncio.wait(tasks)
+
+ async def restart(self) -> None:
+ # Ensure we are in an asyncio environment.
+ assert asyncio.get_event_loop() is not None
+ # This may be called from a different thread than the one the server is
+ # running on. For this reason, we use an event to coordinate with the server
+ # instead of calling shutdown()/startup() directly.
+ self.restart_requested.set()
+ self.started = False
+ while not self.started:
+ await asyncio.sleep(0.5)
+
+ async def watch_restarts(self):
+ while True:
+ if self.should_exit:
+ return
+
+ try:
+ await asyncio.wait_for(self.restart_requested.wait(), timeout=0.1)
+ except asyncio.TimeoutError:
+ continue
+
+ self.restart_requested.clear()
+ await self.shutdown()
+ await self.startup()
+
+
@pytest.fixture
-async def server():
- config = Config(app=app, lifespan="off")
- server = Server(config=config)
- task = asyncio.ensure_future(server.serve())
+def restart(backend):
+ """Restart the running server from an async test function.
+
+ This fixture deals with possible differences between the environment of the
+ test function and that of the server.
+ """
+
+ async def restart(server):
+ await backend.run_in_threadpool(AsyncioBackend().run, server.restart)
+
+ return restart
+
+
+def serve_in_thread(server: Server):
+ thread = threading.Thread(target=server.run)
+ thread.start()
try:
while not server.started:
- await asyncio.sleep(0.0001)
+ time.sleep(1e-3)
yield server
finally:
server.should_exit = True
- await task
+ thread.join()
-@pytest.fixture
-async def https_server(cert_pem_file, cert_private_key_file):
+@pytest.fixture(scope=SERVER_SCOPE)
+def server():
+ config = Config(app=app, lifespan="off", loop="asyncio")
+ server = TestServer(config=config)
+ yield from serve_in_thread(server)
+
+
+@pytest.fixture(scope=SERVER_SCOPE)
+def https_server(cert_pem_file, cert_private_key_file):
config = Config(
app=app,
lifespan="off",
ssl_certfile=cert_pem_file,
ssl_keyfile=cert_private_key_file,
port=8001,
+ loop="asyncio",
)
- server = Server(config=config)
- task = asyncio.ensure_future(server.serve())
- try:
- while not server.started:
- await asyncio.sleep(0.0001)
- yield server
- finally:
- server.should_exit = True
- await task
-
-
-@pytest.fixture
-def restart(backend):
- async def asyncio_restart(server):
- await server.shutdown()
- await server.startup()
-
- if isinstance(backend, AsyncioBackend):
- return asyncio_restart
-
- # The uvicorn server runs under asyncio, so we will need to figure out
- # how to restart it under a different I/O library.
- # This will most likely require running `asyncio_restart` in the threadpool,
- # but that might not be sufficient.
- raise NotImplementedError
+ server = TestServer(config=config)
+ yield from serve_in_thread(server)
response = await http.request("GET", "http://127.0.0.1:8000/")
await response.read()
- # shutdown the server to close the keep-alive connection
+ # Shutdown the server to close the keep-alive connection
await restart(server)
response = await http.request("GET", "http://127.0.0.1:8000/")
response = await http.request("GET", "http://127.0.0.1:8000/")
await response.read()
- # shutdown the server to close the keep-alive connection
+ # Shutdown the server to close the keep-alive connection
await restart(server)
response = await http.request("GET", "http://127.0.0.1:8000/")
async def test_get(server, backend):
- conn = HTTPConnection(origin="http://127.0.0.1:8000/", backend=backend)
- response = await conn.request("GET", "http://127.0.0.1:8000/")
- await response.read()
- assert response.status_code == 200
- assert response.content == b"Hello, world!"
+ async with HTTPConnection(origin="http://127.0.0.1:8000/", backend=backend) as conn:
+ response = await conn.request("GET", "http://127.0.0.1:8000/")
+ await response.read()
+ assert response.status_code == 200
+ assert response.content == b"Hello, world!"
async def test_post(server, backend):
- conn = HTTPConnection(origin="http://127.0.0.1:8000/", backend=backend)
- response = await conn.request(
- "GET", "http://127.0.0.1:8000/", data=b"Hello, world!"
- )
- assert response.status_code == 200
+ async with HTTPConnection(origin="http://127.0.0.1:8000/", backend=backend) as conn:
+ response = await conn.request(
+ "GET", "http://127.0.0.1:8000/", data=b"Hello, world!"
+ )
+ assert response.status_code == 200
async def test_https_get_with_ssl_defaults(https_server, backend):
"""
An HTTPS request, with default SSL configuration set on the client.
"""
- conn = HTTPConnection(
+ async with HTTPConnection(
origin="https://127.0.0.1:8001/", verify=False, backend=backend
- )
- response = await conn.request("GET", "https://127.0.0.1:8001/")
- await response.read()
- assert response.status_code == 200
- assert response.content == b"Hello, world!"
+ ) as conn:
+ response = await conn.request("GET", "https://127.0.0.1:8001/")
+ await response.read()
+ assert response.status_code == 200
+ assert response.content == b"Hello, world!"
async def test_https_get_with_sll_overrides(https_server, backend):
"""
An HTTPS request, with SSL configuration set on the request.
"""
- conn = HTTPConnection(origin="https://127.0.0.1:8001/", backend=backend)
- response = await conn.request("GET", "https://127.0.0.1:8001/", verify=False)
- await response.read()
- assert response.status_code == 200
- assert response.content == b"Hello, world!"
+ async with HTTPConnection(
+ origin="https://127.0.0.1:8001/", backend=backend
+ ) as conn:
+ response = await conn.request("GET", "https://127.0.0.1:8001/", verify=False)
+ await response.read()
+ assert response.status_code == 200
+ assert response.content == b"Hello, world!"
-import asyncio
-import functools
-
import pytest
import httpx
-def threadpool(func):
- """
- Our sync tests should run in separate thread to the uvicorn server.
- """
-
- @functools.wraps(func)
- async def wrapped(*args, **kwargs):
- nonlocal func
-
- loop = asyncio.get_event_loop()
- if kwargs:
- func = functools.partial(func, **kwargs)
- await loop.run_in_executor(None, func, *args)
-
- return pytest.mark.asyncio(wrapped)
-
-
-@threadpool
def test_get(server):
response = httpx.get("http://127.0.0.1:8000/")
assert response.status_code == 200
assert response.http_version == "HTTP/1.1"
-@threadpool
def test_post(server):
response = httpx.post("http://127.0.0.1:8000/", data=b"Hello, world!")
assert response.status_code == 200
assert response.reason_phrase == "OK"
-@threadpool
def test_post_byte_iterator(server):
def data():
yield b"Hello"
assert response.reason_phrase == "OK"
-@threadpool
def test_options(server):
response = httpx.options("http://127.0.0.1:8000/")
assert response.status_code == 200
assert response.reason_phrase == "OK"
-@threadpool
def test_head(server):
response = httpx.head("http://127.0.0.1:8000/")
assert response.status_code == 200
assert response.reason_phrase == "OK"
-@threadpool
def test_put(server):
response = httpx.put("http://127.0.0.1:8000/", data=b"Hello, world!")
assert response.status_code == 200
assert response.reason_phrase == "OK"
-@threadpool
def test_patch(server):
response = httpx.patch("http://127.0.0.1:8000/", data=b"Hello, world!")
assert response.status_code == 200
assert response.reason_phrase == "OK"
-@threadpool
def test_delete(server):
response = httpx.delete("http://127.0.0.1:8000/")
assert response.status_code == 200
assert response.reason_phrase == "OK"
-@threadpool
def test_get_invalid_url(server):
with pytest.raises(httpx.InvalidURL):
httpx.get("invalid://example.org")
timeout = TimeoutConfig(5)
stream = await backend.connect("127.0.0.1", 8001, None, timeout)
- assert stream.is_connection_dropped() is False
- assert stream.stream_writer.get_extra_info("cipher", default=None) is None
- stream = await backend.start_tls(stream, "127.0.0.1", ctx, timeout)
- assert stream.is_connection_dropped() is False
- assert stream.stream_writer.get_extra_info("cipher", default=None) is not None
+ try:
+ assert stream.is_connection_dropped() is False
+ assert stream.stream_writer.get_extra_info("cipher", default=None) is None
- await stream.write(b"GET / HTTP/1.1\r\n\r\n")
- assert (await stream.read(8192, timeout)).startswith(b"HTTP/1.1 200 OK\r\n")
+ stream = await backend.start_tls(stream, "127.0.0.1", ctx, timeout)
+ assert stream.is_connection_dropped() is False
+ assert stream.stream_writer.get_extra_info("cipher", default=None) is not None
+
+ await stream.write(b"GET / HTTP/1.1\r\n\r\n")
+ assert (await stream.read(8192, timeout)).startswith(b"HTTP/1.1 200 OK\r\n")
+
+ finally:
+ await stream.close()