]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Add SyncConnectionPool
authorTom Christie <tom@tomchristie.com>
Tue, 23 Apr 2019 10:12:37 +0000 (11:12 +0100)
committerTom Christie <tom@tomchristie.com>
Tue, 23 Apr 2019 10:12:37 +0000 (11:12 +0100)
httpcore/__init__.py
httpcore/compat.py [new file with mode: 0644]
httpcore/datastructures.py
httpcore/pool.py
httpcore/sync.py
tests/test_sync.py [new file with mode: 0644]

index 32ff66d6cb1f6cd447b440de22ad80baef357d0c..45bf54e131a94a0ebde5905d2095188e336a505e 100644 (file)
@@ -11,5 +11,6 @@ from .exceptions import (
     Timeout,
 )
 from .pool import ConnectionPool
+from .sync import SyncClient, SyncConnectionPool
 
 __version__ = "0.1.1"
diff --git a/httpcore/compat.py b/httpcore/compat.py
new file mode 100644 (file)
index 0000000..a16ebcc
--- /dev/null
@@ -0,0 +1,51 @@
+import asyncio
+
+if hasattr(asyncio, "run"):
+    asyncio_run = asyncio.run
+else:  # pragma: nocover
+
+    def asyncio_run(main, *, debug=False):  # type: ignore
+        if asyncio._get_running_loop() is not None:
+            raise RuntimeError(
+                "asyncio.run() cannot be called from a running event loop"
+            )
+
+        if not asyncio.iscoroutine(main):
+            raise ValueError("a coroutine was expected, got {!r}".format(main))
+
+        loop = asyncio.new_event_loop()
+        try:
+            asyncio.set_event_loop(loop)
+            loop.set_debug(debug)
+            return loop.run_until_complete(main)
+        finally:
+            try:
+                _cancel_all_tasks(loop)
+                loop.run_until_complete(loop.shutdown_asyncgens())
+            finally:
+                asyncio.set_event_loop(None)
+                loop.close()
+
+    def _cancel_all_tasks(loop):  # type: ignore
+        to_cancel = asyncio.all_tasks(loop)
+        if not to_cancel:
+            return
+
+        for task in to_cancel:
+            task.cancel()
+
+        loop.run_until_complete(
+            tasks.gather(*to_cancel, loop=loop, return_exceptions=True)
+        )
+
+        for task in to_cancel:
+            if task.cancelled():
+                continue
+            if task.exception() is not None:
+                loop.call_exception_handler(
+                    {
+                        "message": "unhandled exception during asyncio.run() shutdown",
+                        "exception": task.exception(),
+                        "task": task,
+                    }
+                )
index 7389d451d8a59b60f4f7929ed5b3c71e6ed124b2..016ae8a9869d41ac1f77cbf499ab4ce251df0822 100644 (file)
@@ -101,7 +101,7 @@ class Request:
         headers: typing.Sequence[typing.Tuple[bytes, bytes]] = (),
         body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"",
     ):
-        self.method = method
+        self.method = method.upper()
         self.url = URL(url) if isinstance(url, str) else url
         self.headers = list(headers)
         if isinstance(body, bytes):
index a13657185a436328a304453590b405660abed9f6..a61e67c1badf2c88d7ef6f9f97aef078f64b6424 100644 (file)
@@ -106,13 +106,21 @@ class ConnectionPool(Client):
 
 class ConnectionSemaphore:
     def __init__(self, max_connections: int = None):
-        if max_connections is not None:
-            self.semaphore = asyncio.BoundedSemaphore(value=max_connections)
+        self.max_connections = max_connections
+
+    @property
+    def semaphore(self) -> typing.Optional[asyncio.BoundedSemaphore]:
+        if not hasattr(self, "_semaphore"):
+            if self.max_connections is None:
+                self._semaphore = None
+            else:
+                self._semaphore = asyncio.BoundedSemaphore(value=self.max_connections)
+        return self._semaphore
 
     async def acquire(self) -> None:
-        if hasattr(self, "semaphore"):
+        if self.semaphore is not None:
             await self.semaphore.acquire()
 
     def release(self) -> None:
-        if hasattr(self, "semaphore"):
+        if self.semaphore is not None:
             self.semaphore.release()
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..ac2295c90d764d5af2250280b1ac3a89ec7af393 100644 (file)
@@ -0,0 +1,79 @@
+import typing
+from types import TracebackType
+
+from .compat import asyncio_run
+from .config import SSLConfig, TimeoutConfig
+from .datastructures import URL, Client, Response
+from .pool import ConnectionPool
+
+
+class SyncResponse:
+    def __init__(self, response: Response):
+        self._response = response
+
+    @property
+    def status_code(self) -> int:
+        return self._response.status_code
+
+    @property
+    def reason(self) -> str:
+        return self._response.reason
+
+    @property
+    def headers(self) -> typing.List[typing.Tuple[bytes, bytes]]:
+        return self._response.headers
+
+    @property
+    def body(self) -> bytes:
+        return self._response.body
+
+    def read(self) -> bytes:
+        return asyncio_run(self._response.read())
+
+
+class SyncClient:
+    def __init__(self, client: Client):
+        self._client = client
+
+    def request(
+        self,
+        method: str,
+        url: typing.Union[str, URL],
+        *,
+        headers: typing.Sequence[typing.Tuple[bytes, bytes]] = (),
+        body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"",
+        ssl: typing.Optional[SSLConfig] = None,
+        timeout: typing.Optional[TimeoutConfig] = None,
+        stream: bool = False,
+    ) -> SyncResponse:
+        response = asyncio_run(
+            self._client.request(
+                method,
+                url,
+                headers=headers,
+                body=body,
+                ssl=ssl,
+                timeout=timeout,
+                stream=stream,
+            )
+        )
+        return SyncResponse(response)
+
+    def close(self) -> None:
+        asyncio_run(self._client.close())
+
+    def __enter__(self) -> "SyncClient":
+        return self
+
+    def __exit__(
+        self,
+        exc_type: typing.Type[BaseException] = None,
+        exc_value: BaseException = None,
+        traceback: TracebackType = None,
+    ) -> None:
+        self.close()
+
+
+def SyncConnectionPool(*args: typing.Any, **kwargs: typing.Any) -> SyncClient:
+    client = ConnectionPool(*args, **kwargs)  # type: ignore
+    return SyncClient(client)
diff --git a/tests/test_sync.py b/tests/test_sync.py
new file mode 100644 (file)
index 0000000..7768cbc
--- /dev/null
@@ -0,0 +1,38 @@
+import asyncio
+import functools
+
+import pytest
+
+import httpcore
+
+
+def threadpool(func):
+    """
+    Our sync tests should run in seperate thread to the uvicorn server.
+    """
+
+    @functools.wraps(func)
+    async def wrapped(*args, **kwargs):
+        nonlocal func
+
+        loop = asyncio.get_event_loop()
+        if kwargs:
+            func = functools.partial(func, **kwargs)
+        await loop.run_in_executor(None, func, *args)
+
+    return pytest.mark.asyncio(wrapped)
+
+
+@threadpool
+def test_get(server):
+    with httpcore.SyncConnectionPool() as http:
+        response = http.request("GET", "http://127.0.0.1:8000/")
+    assert response.status_code == 200
+    assert response.body == b"Hello, world!"
+
+
+@threadpool
+def test_post(server):
+    with httpcore.SyncConnectionPool() as http:
+        response = http.request("POST", "http://127.0.0.1:8000/", body=b"Hello, world!")
+    assert response.status_code == 200