]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Sync or Async dispatch (#83)
authorTom Christie <tom@tomchristie.com>
Mon, 10 Jun 2019 11:26:03 +0000 (12:26 +0100)
committerGitHub <noreply@github.com>
Mon, 10 Jun 2019 11:26:03 +0000 (12:26 +0100)
* Support thread-pooled dispatch

* Add ConcurrencyBackend.run

* Initial work towards support byte-iterators on sync request data

* Test case for byte iterator content

* byte iterator support for RequestData

* Add BaseResponse

* Bridge sync/async data in SyncResponse

* Add BaseClient

* SyncResponse -> Response

* Tweaking type annotation

* Distinct classes for Request, AsyncRequest

* Tweak is_streaming, content in BaseRequest

* Stream handling moves to client

* Handle mediating to AsyncResponse from a standard sync Dispatcher class

* Working on thread-pooled dispatcher

* Support threaded dispatch, inc. streaming requests/responses

* Increase test coverage

* Coverage and tweaks

* Include Accept and User-Agent headers by default

22 files changed:
httpcore/__init__.py
httpcore/api.py
httpcore/auth.py
httpcore/client.py
httpcore/concurrency.py
httpcore/dispatch/connection.py
httpcore/dispatch/connection_pool.py
httpcore/dispatch/http11.py
httpcore/dispatch/http2.py
httpcore/dispatch/threaded.py [new file with mode: 0644]
httpcore/interfaces.py
httpcore/models.py
tests/client/test_auth.py
tests/client/test_cookies.py
tests/client/test_redirects.py
tests/dispatch/test_connection_pools.py
tests/dispatch/test_connections.py
tests/dispatch/test_threaded.py [new file with mode: 0644]
tests/models/test_requests.py
tests/models/test_responses.py
tests/test_api.py
tests/test_decoders.py

index 385c35a7ffa5029cc028353e0e492ce5fb64a6de..8f9cb07c03a13d0b0e568d701755ae0615d82707 100644 (file)
@@ -28,8 +28,25 @@ from .exceptions import (
     TooManyRedirects,
     WriteTimeout,
 )
-from .interfaces import BaseReader, BaseWriter, ConcurrencyBackend, Dispatcher, Protocol
-from .models import URL, Cookies, Headers, Origin, QueryParams, Request, Response
+from .interfaces import (
+    AsyncDispatcher,
+    BaseReader,
+    BaseWriter,
+    ConcurrencyBackend,
+    Dispatcher,
+    Protocol,
+)
+from .models import (
+    URL,
+    AsyncRequest,
+    AsyncResponse,
+    Cookies,
+    Headers,
+    Origin,
+    QueryParams,
+    Request,
+    Response,
+)
 from .status_codes import StatusCode, codes
 
 __version__ = "0.4.0"
index 33d68c5e77832fcb66558bb9d667c44b3700015b..7e2682567f7dc387bbf726ffe67729c24b026764 100644 (file)
@@ -8,7 +8,7 @@ from .models import (
     HeaderTypes,
     QueryParamTypes,
     RequestData,
-    SyncResponse,
+    Response,
     URLTypes,
 )
 
@@ -30,7 +30,7 @@ def request(
     cert: CertTypes = None,
     verify: VerifyTypes = True,
     stream: bool = False,
-) -> SyncResponse:
+) -> Response:
     with Client() as client:
         return client.request(
             method=method,
@@ -61,7 +61,7 @@ def get(
     cert: CertTypes = None,
     verify: VerifyTypes = True,
     timeout: TimeoutTypes = None,
-) -> SyncResponse:
+) -> Response:
     return request(
         "GET",
         url,
@@ -88,7 +88,7 @@ def options(
     cert: CertTypes = None,
     verify: VerifyTypes = True,
     timeout: TimeoutTypes = None,
-) -> SyncResponse:
+) -> Response:
     return request(
         "OPTIONS",
         url,
@@ -115,7 +115,7 @@ def head(
     cert: CertTypes = None,
     verify: VerifyTypes = True,
     timeout: TimeoutTypes = None,
-) -> SyncResponse:
+) -> Response:
     return request(
         "HEAD",
         url,
@@ -144,7 +144,7 @@ def post(
     cert: CertTypes = None,
     verify: VerifyTypes = True,
     timeout: TimeoutTypes = None,
-) -> SyncResponse:
+) -> Response:
     return request(
         "POST",
         url,
@@ -175,7 +175,7 @@ def put(
     cert: CertTypes = None,
     verify: VerifyTypes = True,
     timeout: TimeoutTypes = None,
-) -> SyncResponse:
+) -> Response:
     return request(
         "PUT",
         url,
@@ -206,7 +206,7 @@ def patch(
     cert: CertTypes = None,
     verify: VerifyTypes = True,
     timeout: TimeoutTypes = None,
-) -> SyncResponse:
+) -> Response:
     return request(
         "PATCH",
         url,
@@ -237,7 +237,7 @@ def delete(
     cert: CertTypes = None,
     verify: VerifyTypes = True,
     timeout: TimeoutTypes = None,
-) -> SyncResponse:
+) -> Response:
     return request(
         "DELETE",
         url,
index 49ff998b430a1e9285b18ce3b49bc9275e67e3cd..6a39c1b2c5661979ee1b3134426dacc6c33f57e5 100644 (file)
@@ -1,7 +1,7 @@
 import typing
 from base64 import b64encode
 
-from .models import Request
+from .models import AsyncRequest
 
 
 class AuthBase:
@@ -9,7 +9,7 @@ class AuthBase:
     Base class that all auth implementations derive from.
     """
 
-    def __call__(self, request: Request) -> Request:
+    def __call__(self, request: AsyncRequest) -> AsyncRequest:
         raise NotImplementedError("Auth hooks must be callable.")  # pragma: nocover
 
 
@@ -20,7 +20,7 @@ class HTTPBasicAuth(AuthBase):
         self.username = username
         self.password = password
 
-    def __call__(self, request: Request) -> Request:
+    def __call__(self, request: AsyncRequest) -> AsyncRequest:
         request.headers["Authorization"] = self.build_auth_header()
         return request
 
index 2946a753fdb9f84bf844126e9ceb5f2b2ab77de9..0fc60a1a3ba843bf6b38c6e454bd4f3ff78d6a20 100644 (file)
@@ -1,8 +1,8 @@
-import asyncio
 import typing
 from types import TracebackType
 
 from .auth import HTTPBasicAuth
+from .concurrency import AsyncioBackend
 from .config import (
     DEFAULT_MAX_REDIRECTS,
     DEFAULT_POOL_LIMITS,
@@ -13,10 +13,15 @@ from .config import (
     VerifyTypes,
 )
 from .dispatch.connection_pool import ConnectionPool
+from .dispatch.threaded import ThreadedDispatcher
 from .exceptions import RedirectBodyUnavailable, RedirectLoop, TooManyRedirects
-from .interfaces import ConcurrencyBackend, Dispatcher
+from .interfaces import AsyncDispatcher, ConcurrencyBackend, Dispatcher
 from .models import (
     URL,
+    AsyncRequest,
+    AsyncRequestData,
+    AsyncResponse,
+    AsyncResponseContent,
     AuthTypes,
     Cookies,
     CookieTypes,
@@ -26,13 +31,13 @@ from .models import (
     Request,
     RequestData,
     Response,
-    SyncResponse,
+    ResponseContent,
     URLTypes,
 )
 from .status_codes import codes
 
 
-class AsyncClient:
+class BaseClient:
     def __init__(
         self,
         auth: AuthTypes = None,
@@ -42,23 +47,208 @@ class AsyncClient:
         timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG,
         pool_limits: PoolLimits = DEFAULT_POOL_LIMITS,
         max_redirects: int = DEFAULT_MAX_REDIRECTS,
-        dispatch: Dispatcher = None,
+        dispatch: typing.Union[AsyncDispatcher, Dispatcher] = None,
         backend: ConcurrencyBackend = None,
     ):
+        if backend is None:
+            backend = AsyncioBackend()
+
         if dispatch is None:
-            dispatch = ConnectionPool(
+            async_dispatch = ConnectionPool(
                 verify=verify,
                 cert=cert,
                 timeout=timeout,
                 pool_limits=pool_limits,
                 backend=backend,
-            )
+            )  # type: AsyncDispatcher
+        elif isinstance(dispatch, Dispatcher):
+            async_dispatch = ThreadedDispatcher(dispatch, backend)
+        else:
+            async_dispatch = dispatch
 
         self.auth = auth
         self.cookies = Cookies(cookies)
         self.max_redirects = max_redirects
-        self.dispatch = dispatch
+        self.dispatch = async_dispatch
+        self.concurrency_backend = backend
+
+    def merge_cookies(
+        self, cookies: CookieTypes = None
+    ) -> typing.Optional[CookieTypes]:
+        if cookies or self.cookies:
+            merged_cookies = Cookies(self.cookies)
+            merged_cookies.update(cookies)
+            return merged_cookies
+        return cookies
+
+    async def send(
+        self,
+        request: AsyncRequest,
+        *,
+        stream: bool = False,
+        auth: AuthTypes = None,
+        allow_redirects: bool = True,
+        verify: VerifyTypes = None,
+        cert: CertTypes = None,
+        timeout: TimeoutTypes = None,
+    ) -> AsyncResponse:
+        if auth is None:
+            auth = self.auth
+
+        url = request.url
+        if auth is None and (url.username or url.password):
+            auth = HTTPBasicAuth(username=url.username, password=url.password)
+
+        if auth is not None:
+            if isinstance(auth, tuple):
+                auth = HTTPBasicAuth(username=auth[0], password=auth[1])
+            request = auth(request)
+
+        response = await self.send_handling_redirects(
+            request,
+            verify=verify,
+            cert=cert,
+            timeout=timeout,
+            allow_redirects=allow_redirects,
+        )
+
+        if not stream:
+            try:
+                await response.read()
+            finally:
+                await response.close()
+
+        return response
+
+    async def send_handling_redirects(
+        self,
+        request: AsyncRequest,
+        *,
+        cert: CertTypes = None,
+        verify: VerifyTypes = None,
+        timeout: TimeoutTypes = None,
+        allow_redirects: bool = True,
+        history: typing.List[AsyncResponse] = None,
+    ) -> AsyncResponse:
+        if history is None:
+            history = []
+
+        while True:
+            # We perform these checks here, so that calls to `response.next()`
+            # will raise redirect errors if appropriate.
+            if len(history) > self.max_redirects:
+                raise TooManyRedirects()
+            if request.url in [response.url for response in history]:
+                raise RedirectLoop()
+
+            response = await self.dispatch.send(
+                request, verify=verify, cert=cert, timeout=timeout
+            )
+            assert isinstance(response, AsyncResponse)
+            response.history = list(history)
+            self.cookies.extract_cookies(response)
+            history = [response] + history
+            if not response.is_redirect:
+                break
+
+            if allow_redirects:
+                request = self.build_redirect_request(request, response)
+            else:
+
+                async def send_next() -> AsyncResponse:
+                    nonlocal request, response, verify, cert, allow_redirects, timeout, history
+                    request = self.build_redirect_request(request, response)
+                    response = await self.send_handling_redirects(
+                        request,
+                        allow_redirects=allow_redirects,
+                        verify=verify,
+                        cert=cert,
+                        timeout=timeout,
+                        history=history,
+                    )
+                    return response
+
+                response.next = send_next  # type: ignore
+                break
+
+        return response
+
+    def build_redirect_request(
+        self, request: AsyncRequest, response: AsyncResponse
+    ) -> AsyncRequest:
+        method = self.redirect_method(request, response)
+        url = self.redirect_url(request, response)
+        headers = self.redirect_headers(request, url)
+        content = self.redirect_content(request, method)
+        cookies = self.merge_cookies(request.cookies)
+        return AsyncRequest(
+            method=method, url=url, headers=headers, data=content, cookies=cookies
+        )
+
+    def redirect_method(self, request: AsyncRequest, response: AsyncResponse) -> str:
+        """
+        When being redirected we may want to change the method of the request
+        based on certain specs or browser behavior.
+        """
+        method = request.method
+
+        # https://tools.ietf.org/html/rfc7231#section-6.4.4
+        if response.status_code == codes.SEE_OTHER and method != "HEAD":
+            method = "GET"
+
+        # Do what the browsers do, despite standards...
+        # Turn 302s into GETs.
+        if response.status_code == codes.FOUND and method != "HEAD":
+            method = "GET"
+
+        # If a POST is responded to with a 301, turn it into a GET.
+        # This bizarre behaviour is explained in 'requests' issue 1704.
+        if response.status_code == codes.MOVED_PERMANENTLY and method == "POST":
+            method = "GET"
+
+        return method
+
+    def redirect_url(self, request: AsyncRequest, response: AsyncResponse) -> URL:
+        """
+        Return the URL for the redirect to follow.
+        """
+        location = response.headers["Location"]
+
+        url = URL(location, allow_relative=True)
+
+        # Facilitate relative 'Location' headers, as allowed by RFC 7231.
+        # (e.g. '/path/to/resource' instead of 'http://domain.tld/path/to/resource')
+        if url.is_relative_url:
+            url = url.resolve_with(request.url)
+
+        # Attach previous fragment if needed (RFC 7231 7.1.2)
+        if request.url.fragment and not url.fragment:
+            url = url.copy_with(fragment=request.url.fragment)
+
+        return url
+
+    def redirect_headers(self, request: AsyncRequest, url: URL) -> Headers:
+        """
+        Strip Authorization headers when responses are redirected away from
+        the origin.
+        """
+        headers = Headers(request.headers)
+        if url.origin != request.url.origin:
+            del headers["Authorization"]
+        return headers
+
+    def redirect_content(self, request: AsyncRequest, method: str) -> bytes:
+        """
+        Return the body that should be used for the redirect request.
+        """
+        if method != request.method and method == "GET":
+            return b""
+        if request.is_streaming:
+            raise RedirectBodyUnavailable()
+        return request.content
 
+
+class AsyncClient(BaseClient):
     async def get(
         self,
         url: URLTypes,
@@ -72,7 +262,7 @@ class AsyncClient:
         cert: CertTypes = None,
         verify: VerifyTypes = None,
         timeout: TimeoutTypes = None,
-    ) -> Response:
+    ) -> AsyncResponse:
         return await self.request(
             "GET",
             url,
@@ -100,7 +290,7 @@ class AsyncClient:
         cert: CertTypes = None,
         verify: VerifyTypes = None,
         timeout: TimeoutTypes = None,
-    ) -> Response:
+    ) -> AsyncResponse:
         return await self.request(
             "OPTIONS",
             url,
@@ -128,7 +318,7 @@ class AsyncClient:
         cert: CertTypes = None,
         verify: VerifyTypes = None,
         timeout: TimeoutTypes = None,
-    ) -> Response:
+    ) -> AsyncResponse:
         return await self.request(
             "HEAD",
             url,
@@ -147,7 +337,7 @@ class AsyncClient:
         self,
         url: URLTypes,
         *,
-        data: RequestData = b"",
+        data: AsyncRequestData = b"",
         json: typing.Any = None,
         params: QueryParamTypes = None,
         headers: HeaderTypes = None,
@@ -158,7 +348,7 @@ class AsyncClient:
         cert: CertTypes = None,
         verify: VerifyTypes = None,
         timeout: TimeoutTypes = None,
-    ) -> Response:
+    ) -> AsyncResponse:
         return await self.request(
             "POST",
             url,
@@ -179,7 +369,7 @@ class AsyncClient:
         self,
         url: URLTypes,
         *,
-        data: RequestData = b"",
+        data: AsyncRequestData = b"",
         json: typing.Any = None,
         params: QueryParamTypes = None,
         headers: HeaderTypes = None,
@@ -190,7 +380,7 @@ class AsyncClient:
         cert: CertTypes = None,
         verify: VerifyTypes = None,
         timeout: TimeoutTypes = None,
-    ) -> Response:
+    ) -> AsyncResponse:
         return await self.request(
             "PUT",
             url,
@@ -211,7 +401,7 @@ class AsyncClient:
         self,
         url: URLTypes,
         *,
-        data: RequestData = b"",
+        data: AsyncRequestData = b"",
         json: typing.Any = None,
         params: QueryParamTypes = None,
         headers: HeaderTypes = None,
@@ -222,7 +412,7 @@ class AsyncClient:
         cert: CertTypes = None,
         verify: VerifyTypes = None,
         timeout: TimeoutTypes = None,
-    ) -> Response:
+    ) -> AsyncResponse:
         return await self.request(
             "PATCH",
             url,
@@ -243,7 +433,7 @@ class AsyncClient:
         self,
         url: URLTypes,
         *,
-        data: RequestData = b"",
+        data: AsyncRequestData = b"",
         json: typing.Any = None,
         params: QueryParamTypes = None,
         headers: HeaderTypes = None,
@@ -254,7 +444,7 @@ class AsyncClient:
         cert: CertTypes = None,
         verify: VerifyTypes = None,
         timeout: TimeoutTypes = None,
-    ) -> Response:
+    ) -> AsyncResponse:
         return await self.request(
             "DELETE",
             url,
@@ -276,7 +466,7 @@ class AsyncClient:
         method: str,
         url: URLTypes,
         *,
-        data: RequestData = b"",
+        data: AsyncRequestData = b"",
         json: typing.Any = None,
         params: QueryParamTypes = None,
         headers: HeaderTypes = None,
@@ -287,8 +477,8 @@ class AsyncClient:
         cert: CertTypes = None,
         verify: VerifyTypes = None,
         timeout: TimeoutTypes = None,
-    ) -> Response:
-        request = Request(
+    ) -> AsyncResponse:
+        request = AsyncRequest(
             method,
             url,
             data=data,
@@ -308,174 +498,6 @@ class AsyncClient:
         )
         return response
 
-    def merge_cookies(
-        self, cookies: CookieTypes = None
-    ) -> typing.Optional[CookieTypes]:
-        if cookies or self.cookies:
-            merged_cookies = Cookies(self.cookies)
-            merged_cookies.update(cookies)
-            return merged_cookies
-        return cookies
-
-    async def send(
-        self,
-        request: Request,
-        *,
-        stream: bool = False,
-        auth: AuthTypes = None,
-        allow_redirects: bool = True,
-        verify: VerifyTypes = None,
-        cert: CertTypes = None,
-        timeout: TimeoutTypes = None,
-    ) -> Response:
-        if auth is None:
-            auth = self.auth
-
-        url = request.url
-        if auth is None and (url.username or url.password):
-            auth = HTTPBasicAuth(username=url.username, password=url.password)
-
-        if auth is not None:
-            if isinstance(auth, tuple):
-                auth = HTTPBasicAuth(username=auth[0], password=auth[1])
-            request = auth(request)
-
-        response = await self.send_handling_redirects(
-            request,
-            stream=stream,
-            verify=verify,
-            cert=cert,
-            timeout=timeout,
-            allow_redirects=allow_redirects,
-        )
-        return response
-
-    async def send_handling_redirects(
-        self,
-        request: Request,
-        *,
-        stream: bool = False,
-        cert: CertTypes = None,
-        verify: VerifyTypes = None,
-        timeout: TimeoutTypes = None,
-        allow_redirects: bool = True,
-        history: typing.List[Response] = None,
-    ) -> Response:
-        if history is None:
-            history = []
-
-        while True:
-            # We perform these checks here, so that calls to `response.next()`
-            # will raise redirect errors if appropriate.
-            if len(history) > self.max_redirects:
-                raise TooManyRedirects()
-            if request.url in [response.url for response in history]:
-                raise RedirectLoop()
-
-            response = await self.dispatch.send(
-                request, stream=stream, verify=verify, cert=cert, timeout=timeout
-            )
-            response.history = list(history)
-            self.cookies.extract_cookies(response)
-            history = [response] + history
-            if not response.is_redirect:
-                break
-
-            if allow_redirects:
-                request = self.build_redirect_request(request, response)
-            else:
-
-                async def send_next() -> Response:
-                    nonlocal request, response, verify, cert, allow_redirects, timeout, history
-                    request = self.build_redirect_request(request, response)
-                    response = await self.send_handling_redirects(
-                        request,
-                        stream=stream,
-                        allow_redirects=allow_redirects,
-                        verify=verify,
-                        cert=cert,
-                        timeout=timeout,
-                        history=history,
-                    )
-                    return response
-
-                response.next = send_next  # type: ignore
-                break
-
-        return response
-
-    def build_redirect_request(self, request: Request, response: Response) -> Request:
-        method = self.redirect_method(request, response)
-        url = self.redirect_url(request, response)
-        headers = self.redirect_headers(request, url)
-        content = self.redirect_content(request, method)
-        cookies = self.merge_cookies(request.cookies)
-        return Request(
-            method=method, url=url, headers=headers, data=content, cookies=cookies
-        )
-
-    def redirect_method(self, request: Request, response: Response) -> str:
-        """
-        When being redirected we may want to change the method of the request
-        based on certain specs or browser behavior.
-        """
-        method = request.method
-
-        # https://tools.ietf.org/html/rfc7231#section-6.4.4
-        if response.status_code == codes.SEE_OTHER and method != "HEAD":
-            method = "GET"
-
-        # Do what the browsers do, despite standards...
-        # Turn 302s into GETs.
-        if response.status_code == codes.FOUND and method != "HEAD":
-            method = "GET"
-
-        # If a POST is responded to with a 301, turn it into a GET.
-        # This bizarre behaviour is explained in 'requests' issue 1704.
-        if response.status_code == codes.MOVED_PERMANENTLY and method == "POST":
-            method = "GET"
-
-        return method
-
-    def redirect_url(self, request: Request, response: Response) -> URL:
-        """
-        Return the URL for the redirect to follow.
-        """
-        location = response.headers["Location"]
-
-        url = URL(location, allow_relative=True)
-
-        # Facilitate relative 'Location' headers, as allowed by RFC 7231.
-        # (e.g. '/path/to/resource' instead of 'http://domain.tld/path/to/resource')
-        if url.is_relative_url:
-            url = url.resolve_with(request.url)
-
-        # Attach previous fragment if needed (RFC 7231 7.1.2)
-        if request.url.fragment and not url.fragment:
-            url = url.copy_with(fragment=request.url.fragment)
-
-        return url
-
-    def redirect_headers(self, request: Request, url: URL) -> Headers:
-        """
-        Strip Authorization headers when responses are redirected away from
-        the origin.
-        """
-        headers = Headers(request.headers)
-        if url.origin != request.url.origin:
-            del headers["Authorization"]
-        return headers
-
-    def redirect_content(self, request: Request, method: str) -> bytes:
-        """
-        Return the body that should be used for the redirect request.
-        """
-        if method != request.method and method == "GET":
-            return b""
-        if request.is_streaming:
-            raise RedirectBodyUnavailable()
-        return request.content
-
     async def close(self) -> None:
         await self.dispatch.close()
 
@@ -491,33 +513,28 @@ class AsyncClient:
         await self.close()
 
 
-class Client:
-    def __init__(
-        self,
-        auth: AuthTypes = None,
-        cert: CertTypes = None,
-        verify: VerifyTypes = True,
-        timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG,
-        pool_limits: PoolLimits = DEFAULT_POOL_LIMITS,
-        max_redirects: int = DEFAULT_MAX_REDIRECTS,
-        dispatch: Dispatcher = None,
-        backend: ConcurrencyBackend = None,
-    ) -> None:
-        self._client = AsyncClient(
-            auth=auth,
-            verify=verify,
-            cert=cert,
-            timeout=timeout,
-            pool_limits=pool_limits,
-            max_redirects=max_redirects,
-            dispatch=dispatch,
-            backend=backend,
-        )
-        self._loop = asyncio.new_event_loop()
+class Client(BaseClient):
+    def _async_request_data(self, data: RequestData) -> AsyncRequestData:
+        """
+        If the request data is an bytes iterator then return an async bytes
+        iterator onto the request data.
+        """
+        if isinstance(data, (bytes, dict)):
+            return data
+
+        # Coerce an iterator into an async iterator, with each item in the
+        # iteration running as a thread-pooled operation.
+        assert hasattr(data, "__iter__")
+        return self.concurrency_backend.iterate_in_threadpool(data)
 
-    @property
-    def cookies(self) -> Cookies:
-        return self._client.cookies
+    def _sync_data(self, data: AsyncResponseContent) -> ResponseContent:
+        if isinstance(data, bytes):
+            return data
+
+        # Coerce an async iterator into an iterator, with each item in the
+        # iteration run within the event loop.
+        assert hasattr(data, "__aiter__")
+        return self.concurrency_backend.iterate(data)
 
     def request(
         self,
@@ -535,25 +552,55 @@ class Client:
         cert: CertTypes = None,
         verify: VerifyTypes = None,
         timeout: TimeoutTypes = None,
-    ) -> SyncResponse:
-        request = Request(
+    ) -> Response:
+        request = AsyncRequest(
             method,
             url,
-            data=data,
+            data=self._async_request_data(data),
             json=json,
             params=params,
             headers=headers,
-            cookies=self._client.merge_cookies(cookies),
+            cookies=self.merge_cookies(cookies),
         )
-        response = self.send(
-            request,
-            stream=stream,
+        concurrency_backend = self.concurrency_backend
+
+        coroutine = self.send
+        args = [request]
+        kwargs = dict(
+            stream=True,
             auth=auth,
             allow_redirects=allow_redirects,
             verify=verify,
             cert=cert,
             timeout=timeout,
         )
+        async_response = concurrency_backend.run(coroutine, *args, **kwargs)
+
+        content = getattr(
+            async_response, "_raw_content", getattr(async_response, "_raw_stream", None)
+        )
+
+        sync_content = self._sync_data(content)
+
+        def sync_on_close() -> None:
+            nonlocal concurrency_backend, async_response
+            concurrency_backend.run(async_response.on_close)
+
+        response = Response(
+            status_code=async_response.status_code,
+            reason_phrase=async_response.reason_phrase,
+            protocol=async_response.protocol,
+            headers=async_response.headers,
+            content=sync_content,
+            on_close=sync_on_close,
+            request=async_response.request,
+            history=async_response.history,
+        )
+        if not stream:
+            try:
+                response.read()
+            finally:
+                response.close()
         return response
 
     def get(
@@ -569,7 +616,7 @@ class Client:
         cert: CertTypes = None,
         verify: VerifyTypes = None,
         timeout: TimeoutTypes = None,
-    ) -> SyncResponse:
+    ) -> Response:
         return self.request(
             "GET",
             url,
@@ -596,7 +643,7 @@ class Client:
         cert: CertTypes = None,
         verify: VerifyTypes = None,
         timeout: TimeoutTypes = None,
-    ) -> SyncResponse:
+    ) -> Response:
         return self.request(
             "OPTIONS",
             url,
@@ -623,7 +670,7 @@ class Client:
         cert: CertTypes = None,
         verify: VerifyTypes = None,
         timeout: TimeoutTypes = None,
-    ) -> SyncResponse:
+    ) -> Response:
         return self.request(
             "HEAD",
             url,
@@ -652,7 +699,7 @@ class Client:
         cert: CertTypes = None,
         verify: VerifyTypes = None,
         timeout: TimeoutTypes = None,
-    ) -> SyncResponse:
+    ) -> Response:
         return self.request(
             "POST",
             url,
@@ -683,7 +730,7 @@ class Client:
         cert: CertTypes = None,
         verify: VerifyTypes = None,
         timeout: TimeoutTypes = None,
-    ) -> SyncResponse:
+    ) -> Response:
         return self.request(
             "PUT",
             url,
@@ -714,7 +761,7 @@ class Client:
         cert: CertTypes = None,
         verify: VerifyTypes = None,
         timeout: TimeoutTypes = None,
-    ) -> SyncResponse:
+    ) -> Response:
         return self.request(
             "PATCH",
             url,
@@ -745,7 +792,7 @@ class Client:
         cert: CertTypes = None,
         verify: VerifyTypes = None,
         timeout: TimeoutTypes = None,
-    ) -> SyncResponse:
+    ) -> Response:
         return self.request(
             "DELETE",
             url,
@@ -761,32 +808,9 @@ class Client:
             timeout=timeout,
         )
 
-    def send(
-        self,
-        request: Request,
-        *,
-        stream: bool = False,
-        auth: AuthTypes = None,
-        allow_redirects: bool = True,
-        verify: VerifyTypes = None,
-        cert: CertTypes = None,
-        timeout: TimeoutTypes = None,
-    ) -> SyncResponse:
-        response = self._loop.run_until_complete(
-            self._client.send(
-                request,
-                stream=stream,
-                auth=auth,
-                allow_redirects=allow_redirects,
-                verify=verify,
-                cert=cert,
-                timeout=timeout,
-            )
-        )
-        return SyncResponse(response, self._loop)
-
     def close(self) -> None:
-        self._loop.run_until_complete(self._client.close())
+        coroutine = self.dispatch.close
+        self.concurrency_backend.run(coroutine)
 
     def __enter__(self) -> "Client":
         return self
index 0c1d3409eb8152ea2fbeb49e9ece0762ccbbae93..664cb294484f03594cf0c04cbc0a052daedb8b4e 100644 (file)
@@ -9,6 +9,7 @@ protocols, and help keep the rest of the package more `async`/`await`
 based, and less strictly `asyncio`-specific.
 """
 import asyncio
+import functools
 import ssl
 import typing
 
@@ -133,6 +134,15 @@ class AsyncioBackend(ConcurrencyBackend):
             ssl_monkey_patch()
         SSL_MONKEY_PATCH_APPLIED = True
 
+    @property
+    def loop(self) -> asyncio.AbstractEventLoop:
+        if not hasattr(self, "_loop"):
+            try:
+                self._loop = asyncio.get_event_loop()
+            except RuntimeError:
+                self._loop = asyncio.new_event_loop()
+        return self._loop
+
     async def connect(
         self,
         hostname: str,
@@ -162,5 +172,24 @@ class AsyncioBackend(ConcurrencyBackend):
 
         return (reader, writer, protocol)
 
+    async def run_in_threadpool(
+        self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any
+    ) -> typing.Any:
+        if kwargs:
+            # loop.run_in_executor doesn't accept 'kwargs', so bind them in here
+            func = functools.partial(func, **kwargs)
+        return await self.loop.run_in_executor(None, func, *args)
+
+    def run(
+        self, coroutine: typing.Callable, *args: typing.Any, **kwargs: typing.Any
+    ) -> typing.Any:
+        loop = self.loop
+        if loop.is_running():
+            self._loop = asyncio.new_event_loop()
+        try:
+            return self.loop.run_until_complete(coroutine(*args, **kwargs))
+        finally:
+            self._loop = loop
+
     def get_semaphore(self, limits: PoolLimits) -> BasePoolSemaphore:
         return PoolSemaphore(limits)
index 60214333fe990599bc044d662637071a11713a4e..d644bcba73b32094bebbdda0da8bf1396ad88cd9 100644 (file)
@@ -15,8 +15,8 @@ from ..config import (
     VerifyTypes,
 )
 from ..exceptions import ConnectTimeout
-from ..interfaces import ConcurrencyBackend, Dispatcher, Protocol
-from ..models import Origin, Request, Response
+from ..interfaces import AsyncDispatcher, ConcurrencyBackend, Protocol
+from ..models import AsyncRequest, AsyncResponse, Origin
 from .http2 import HTTP2Connection
 from .http11 import HTTP11Connection
 
@@ -24,7 +24,7 @@ from .http11 import HTTP11Connection
 ReleaseCallback = typing.Callable[["HTTPConnection"], typing.Awaitable[None]]
 
 
-class HTTPConnection(Dispatcher):
+class HTTPConnection(AsyncDispatcher):
     def __init__(
         self,
         origin: typing.Union[str, Origin],
@@ -44,24 +44,19 @@ class HTTPConnection(Dispatcher):
 
     async def send(
         self,
-        request: Request,
-        stream: bool = False,
+        request: AsyncRequest,
         verify: VerifyTypes = None,
         cert: CertTypes = None,
         timeout: TimeoutTypes = None,
-    ) -> Response:
+    ) -> AsyncResponse:
         if self.h11_connection is None and self.h2_connection is None:
             await self.connect(verify=verify, cert=cert, timeout=timeout)
 
         if self.h2_connection is not None:
-            response = await self.h2_connection.send(
-                request, stream=stream, timeout=timeout
-            )
+            response = await self.h2_connection.send(request, timeout=timeout)
         else:
             assert self.h11_connection is not None
-            response = await self.h11_connection.send(
-                request, stream=stream, timeout=timeout
-            )
+            response = await self.h11_connection.send(request, timeout=timeout)
 
         return response
 
index e7cefbd7e4d27677badeb25b8a30b914ecdd035c..c84117ca45bb6652437a7a7d9c083b5b0362e2e1 100644 (file)
@@ -12,8 +12,8 @@ from ..config import (
 )
 from ..decoders import ACCEPT_ENCODING
 from ..exceptions import PoolTimeout
-from ..interfaces import ConcurrencyBackend, Dispatcher
-from ..models import Origin, Request, Response
+from ..interfaces import AsyncDispatcher, ConcurrencyBackend
+from ..models import AsyncRequest, AsyncResponse, Origin
 from .connection import HTTPConnection
 
 CONNECTIONS_DICT = typing.Dict[Origin, typing.List[HTTPConnection]]
@@ -77,7 +77,7 @@ class ConnectionStore:
         return len(self.all)
 
 
-class ConnectionPool(Dispatcher):
+class ConnectionPool(AsyncDispatcher):
     def __init__(
         self,
         *,
@@ -105,16 +105,15 @@ class ConnectionPool(Dispatcher):
 
     async def send(
         self,
-        request: Request,
-        stream: bool = False,
+        request: AsyncRequest,
         verify: VerifyTypes = None,
         cert: CertTypes = None,
         timeout: TimeoutTypes = None,
-    ) -> Response:
+    ) -> AsyncResponse:
         connection = await self.acquire_connection(request.url.origin)
         try:
             response = await connection.send(
-                request, stream=stream, verify=verify, cert=cert, timeout=timeout
+                request, verify=verify, cert=cert, timeout=timeout
             )
         except BaseException as exc:
             self.active_connections.remove(connection)
index 6f1548563b68ca80a4bfd9fd0ec4406d2cdd5c53..f19b3d3dc160ac702af74f7940e422c0e6247c5e 100644 (file)
@@ -4,8 +4,8 @@ import h11
 
 from ..config import DEFAULT_TIMEOUT_CONFIG, TimeoutConfig, TimeoutTypes
 from ..exceptions import ConnectTimeout, ReadTimeout
-from ..interfaces import BaseReader, BaseWriter, Dispatcher
-from ..models import Request, Response
+from ..interfaces import BaseReader, BaseWriter
+from ..models import AsyncRequest, AsyncResponse
 
 H11Event = typing.Union[
     h11.Request,
@@ -38,15 +38,15 @@ class HTTP11Connection:
         self.h11_state = h11.Connection(our_role=h11.CLIENT)
 
     async def send(
-        self, request: Request, stream: bool = False, timeout: TimeoutTypes = None
-    ) -> Response:
+        self, request: AsyncRequest, timeout: TimeoutTypes = None
+    ) -> AsyncResponse:
         timeout = None if timeout is None else TimeoutConfig(timeout)
 
         # Â Start sending the request.
         method = request.method.encode("ascii")
         target = request.url.full_path.encode("ascii")
         headers = request.headers.raw
-        if 'Host' not in request.headers:
+        if "Host" not in request.headers:
             host = request.url.authority.encode("ascii")
             headers = [(b"host", host)] + headers
         event = h11.Request(method=method, target=target, headers=headers)
@@ -72,7 +72,7 @@ class HTTP11Connection:
         headers = event.headers
         content = self._body_iter(timeout)
 
-        response = Response(
+        return AsyncResponse(
             status_code=status_code,
             reason_phrase=reason_phrase,
             protocol="HTTP/1.1",
@@ -82,14 +82,6 @@ class HTTP11Connection:
             request=request,
         )
 
-        if not stream:
-            try:
-                await response.read()
-            finally:
-                await response.close()
-
-        return response
-
     async def close(self) -> None:
         event = h11.ConnectionClosed()
         self.h11_state.send(event)
index 402f3b651c6deeb61b3433b273d51af1ffececa9..f7814ec3c1193908f2e89765799c80017be91abb 100644 (file)
@@ -6,8 +6,8 @@ import h2.events
 
 from ..config import DEFAULT_TIMEOUT_CONFIG, TimeoutConfig, TimeoutTypes
 from ..exceptions import ConnectTimeout, ReadTimeout
-from ..interfaces import BaseReader, BaseWriter, Dispatcher
-from ..models import Request, Response
+from ..interfaces import BaseReader, BaseWriter
+from ..models import AsyncRequest, AsyncResponse
 
 
 class HTTP2Connection:
@@ -24,8 +24,8 @@ class HTTP2Connection:
         self.initialized = False
 
     async def send(
-        self, request: Request, stream: bool = False, timeout: TimeoutTypes = None
-    ) -> Response:
+        self, request: AsyncRequest, timeout: TimeoutTypes = None
+    ) -> AsyncResponse:
         timeout = None if timeout is None else TimeoutConfig(timeout)
 
         # Â Start sending the request.
@@ -59,7 +59,7 @@ class HTTP2Connection:
         content = self.body_iter(stream_id, timeout)
         on_close = functools.partial(self.response_closed, stream_id=stream_id)
 
-        response = Response(
+        return AsyncResponse(
             status_code=status_code,
             protocol="HTTP/2",
             headers=headers,
@@ -68,14 +68,6 @@ class HTTP2Connection:
             request=request,
         )
 
-        if not stream:
-            try:
-                await response.read()
-            finally:
-                await response.close()
-
-        return response
-
     async def close(self) -> None:
         await self.writer.close()
 
@@ -86,7 +78,7 @@ class HTTP2Connection:
         self.initialized = True
 
     async def send_headers(
-        self, request: Request, timeout: TimeoutConfig = None
+        self, request: AsyncRequest, timeout: TimeoutConfig = None
     ) -> int:
         stream_id = self.h2_state.get_next_available_stream_id()
         headers = [
diff --git a/httpcore/dispatch/threaded.py b/httpcore/dispatch/threaded.py
new file mode 100644 (file)
index 0000000..dbcd4da
--- /dev/null
@@ -0,0 +1,97 @@
+from ..config import CertTypes, TimeoutTypes, VerifyTypes
+from ..interfaces import AsyncDispatcher, ConcurrencyBackend, Dispatcher
+from ..models import (
+    AsyncRequest,
+    AsyncRequestData,
+    AsyncResponse,
+    AsyncResponseContent,
+    Request,
+    RequestData,
+    Response,
+    ResponseContent,
+)
+
+
+class ThreadedDispatcher(AsyncDispatcher):
+    """
+    The ThreadedDispatcher class is used to mediate between the Client
+    (which always uses async under the hood), and a synchronous `Dispatch`
+    class.
+    """
+
+    def __init__(self, dispatch: Dispatcher, backend: ConcurrencyBackend) -> None:
+        self.sync_dispatcher = dispatch
+        self.backend = backend
+
+    async def send(
+        self,
+        request: AsyncRequest,
+        verify: VerifyTypes = None,
+        cert: CertTypes = None,
+        timeout: TimeoutTypes = None,
+    ) -> AsyncResponse:
+        concurrency_backend = self.backend
+
+        data = getattr(request, "content", getattr(request, "content_aiter", None))
+        sync_data = self._sync_request_data(data)
+
+        sync_request = Request(
+            method=request.method,
+            url=request.url,
+            headers=request.headers,
+            data=sync_data,
+        )
+
+        func = self.sync_dispatcher.send
+        kwargs = {
+            "request": sync_request,
+            "verify": verify,
+            "cert": cert,
+            "timeout": timeout,
+        }
+        sync_response = await self.backend.run_in_threadpool(func, **kwargs)
+        assert isinstance(sync_response, Response)
+
+        content = getattr(
+            sync_response, "_raw_content", getattr(sync_response, "_raw_stream", None)
+        )
+
+        async_content = self._async_response_content(content)
+
+        async def async_on_close() -> None:
+            nonlocal concurrency_backend, sync_response
+            await concurrency_backend.run_in_threadpool(sync_response.close)
+
+        return AsyncResponse(
+            status_code=sync_response.status_code,
+            reason_phrase=sync_response.reason_phrase,
+            protocol=sync_response.protocol,
+            headers=sync_response.headers,
+            content=async_content,
+            on_close=async_on_close,
+            request=request,
+            history=sync_response.history,
+        )
+
+    async def close(self) -> None:
+        """
+        The `.close()` method runs the `Dispatcher.close()` within a threadpool,
+        so as not to block the async event loop.
+        """
+        func = self.sync_dispatcher.close
+        await self.backend.run_in_threadpool(func)
+
+    def _async_response_content(self, content: ResponseContent) -> AsyncResponseContent:
+        if isinstance(content, bytes):
+            return content
+
+        # Coerce an async iterator into an iterator, with each item in the
+        # iteration run within the event loop.
+        assert hasattr(content, "__iter__")
+        return self.backend.iterate_in_threadpool(content)
+
+    def _sync_request_data(self, data: AsyncRequestData) -> RequestData:
+        if isinstance(data, bytes):
+            return data
+
+        return self.backend.iterate(data)
index 42ffd157ae1b8438d7ff5273f24a4dc8a2ee7584..13d118cfcdc9a3f11f4ce64e198b9f29fefda3aa 100644 (file)
@@ -6,6 +6,9 @@ from types import TracebackType
 from .config import CertTypes, PoolLimits, TimeoutConfig, TimeoutTypes, VerifyTypes
 from .models import (
     URL,
+    AsyncRequest,
+    AsyncRequestData,
+    AsyncResponse,
     Headers,
     HeaderTypes,
     QueryParamTypes,
@@ -21,9 +24,9 @@ class Protocol(str, enum.Enum):
     HTTP_2 = "HTTP/2"
 
 
-class Dispatcher:
+class AsyncDispatcher:
     """
-    Base class for dispatcher classes, that handle sending the request.
+    Base class for async dispatcher classes, that handle sending the request.
 
     Stubs out the interface, as well as providing a `.request()` convienence
     implementation, to make it easy to use or test stand-alone dispatchers,
@@ -31,6 +34,54 @@ class Dispatcher:
     """
 
     async def request(
+        self,
+        method: str,
+        url: URLTypes,
+        *,
+        data: AsyncRequestData = b"",
+        params: QueryParamTypes = None,
+        headers: HeaderTypes = None,
+        verify: VerifyTypes = None,
+        cert: CertTypes = None,
+        timeout: TimeoutTypes = None
+    ) -> AsyncResponse:
+        request = AsyncRequest(method, url, data=data, params=params, headers=headers)
+        return await self.send(request, verify=verify, cert=cert, timeout=timeout)
+
+    async def send(
+        self,
+        request: AsyncRequest,
+        verify: VerifyTypes = None,
+        cert: CertTypes = None,
+        timeout: TimeoutTypes = None,
+    ) -> AsyncResponse:
+        raise NotImplementedError()  # pragma: nocover
+
+    async def close(self) -> None:
+        pass  # pragma: nocover
+
+    async def __aenter__(self) -> "AsyncDispatcher":
+        return self
+
+    async def __aexit__(
+        self,
+        exc_type: typing.Type[BaseException] = None,
+        exc_value: BaseException = None,
+        traceback: TracebackType = None,
+    ) -> None:
+        await self.close()
+
+
+class Dispatcher:
+    """
+    Base class for syncronous dispatcher classes, that handle sending the request.
+
+    Stubs out the interface, as well as providing a `.request()` convienence
+    implementation, to make it easy to use or test stand-alone dispatchers,
+    without requiring a complete `Client` instance.
+    """
+
+    def request(
         self,
         method: str,
         url: URLTypes,
@@ -38,40 +89,35 @@ class Dispatcher:
         data: RequestData = b"",
         params: QueryParamTypes = None,
         headers: HeaderTypes = None,
-        stream: bool = False,
         verify: VerifyTypes = None,
         cert: CertTypes = None,
         timeout: TimeoutTypes = None
     ) -> Response:
         request = Request(method, url, data=data, params=params, headers=headers)
-        response = await self.send(
-            request, stream=stream, verify=verify, cert=cert, timeout=timeout
-        )
-        return response
+        return self.send(request, verify=verify, cert=cert, timeout=timeout)
 
-    async def send(
+    def send(
         self,
         request: Request,
-        stream: bool = False,
         verify: VerifyTypes = None,
         cert: CertTypes = None,
         timeout: TimeoutTypes = None,
     ) -> Response:
         raise NotImplementedError()  # pragma: nocover
 
-    async def close(self) -> None:
+    def close(self) -> None:
         pass  # pragma: nocover
 
-    async def __aenter__(self) -> "Dispatcher":
+    def __enter__(self) -> "Dispatcher":
         return self
 
-    async def __aexit__(
+    def __exit__(
         self,
         exc_type: typing.Type[BaseException] = None,
         exc_value: BaseException = None,
         traceback: TracebackType = None,
     ) -> None:
-        await self.close()
+        self.close()
 
 
 class BaseReader:
@@ -128,3 +174,36 @@ class ConcurrencyBackend:
 
     def get_semaphore(self, limits: PoolLimits) -> BasePoolSemaphore:
         raise NotImplementedError()  # pragma: no cover
+
+    async def run_in_threadpool(
+        self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any
+    ) -> typing.Any:
+        raise NotImplementedError()  # pragma: no cover
+
+    async def iterate_in_threadpool(self, iterator):  # type: ignore
+        class IterationComplete(Exception):
+            pass
+
+        def next_wrapper(iterator):  # type: ignore
+            try:
+                return next(iterator)
+            except StopIteration:
+                raise IterationComplete()
+
+        while True:
+            try:
+                yield await self.run_in_threadpool(next_wrapper, iterator)
+            except IterationComplete:
+                break
+
+    def run(
+        self, coroutine: typing.Callable, *args: typing.Any, **kwargs: typing.Any
+    ) -> typing.Any:
+        raise NotImplementedError()  # pragma: no cover
+
+    def iterate(self, async_iterator):  # type: ignore
+        while True:
+            try:
+                yield self.run(async_iterator.__anext__)
+            except StopAsyncIteration:
+                break
index cb28675d2db4e51c865e320d35f61e709515c095..eb610801bc4f9a3da62cd8a91e189588a634bf0c 100644 (file)
@@ -1,4 +1,3 @@
-import asyncio
 import cgi
 import email.message
 import json as jsonlib
@@ -48,12 +47,16 @@ CookieTypes = typing.Union["Cookies", CookieJar, typing.Dict[str, str]]
 
 AuthTypes = typing.Union[
     typing.Tuple[typing.Union[str, bytes], typing.Union[str, bytes]],
-    typing.Callable[["Request"], "Request"],
+    typing.Callable[["AsyncRequest"], "AsyncRequest"],
 ]
 
-RequestData = typing.Union[dict, bytes, typing.AsyncIterator[bytes]]
+AsyncRequestData = typing.Union[dict, bytes, typing.AsyncIterator[bytes]]
 
-ResponseContent = typing.Union[bytes, typing.AsyncIterator[bytes]]
+RequestData = typing.Union[dict, bytes, typing.Iterator[bytes]]
+
+AsyncResponseContent = typing.Union[bytes, typing.AsyncIterator[bytes]]
+
+ResponseContent = typing.Union[bytes, typing.Iterator[bytes]]
 
 
 class URL:
@@ -469,14 +472,12 @@ class Headers(typing.MutableMapping[str, str]):
         return f"{class_name}({as_list!r}{encoding_str})"
 
 
-class Request:
+class BaseRequest:
     def __init__(
         self,
         method: str,
         url: typing.Union[str, URL],
         *,
-        data: RequestData = b"",
-        json: typing.Any = None,
         params: QueryParamTypes = None,
         headers: HeaderTypes = None,
         cookies: CookieTypes = None,
@@ -488,18 +489,82 @@ class Request:
             self._cookies = Cookies(cookies)
             self._cookies.set_cookie_header(self)
 
+    def encode_json(self, json: typing.Any) -> bytes:
+        return jsonlib.dumps(json).encode("utf-8")
+
+    def urlencode_data(self, data: dict) -> bytes:
+        return urlencode(data, doseq=True).encode("utf-8")
+
+    def prepare(self) -> None:
+        content = getattr(self, "content", None)  # type: bytes
+        is_streaming = getattr(self, "is_streaming", False)
+
+        auto_headers = []  # type: typing.List[typing.Tuple[bytes, bytes]]
+
+        has_user_agent = "user-agent" in self.headers
+        has_accept = "accept" in self.headers
+        has_content_length = (
+            "content-length" in self.headers or "transfer-encoding" in self.headers
+        )
+        has_accept_encoding = "accept-encoding" in self.headers
+
+        if not has_user_agent:
+            auto_headers.append((b"user-agent", b"httpcore"))
+        if not has_accept:
+            auto_headers.append((b"accept", b"*/*"))
+        if not has_content_length:
+            if is_streaming:
+                auto_headers.append((b"transfer-encoding", b"chunked"))
+            elif content:
+                content_length = str(len(content)).encode()
+                auto_headers.append((b"content-length", content_length))
+        if not has_accept_encoding:
+            auto_headers.append((b"accept-encoding", ACCEPT_ENCODING.encode()))
+
+        for item in reversed(auto_headers):
+            self.headers.raw.insert(0, item)
+
+    @property
+    def cookies(self) -> "Cookies":
+        if not hasattr(self, "_cookies"):
+            self._cookies = Cookies()
+        return self._cookies
+
+    def __repr__(self) -> str:
+        class_name = self.__class__.__name__
+        url = str(self.url)
+        return f"<{class_name}({self.method!r}, {url!r})>"
+
+
+class AsyncRequest(BaseRequest):
+    def __init__(
+        self,
+        method: str,
+        url: typing.Union[str, URL],
+        *,
+        params: QueryParamTypes = None,
+        headers: HeaderTypes = None,
+        cookies: CookieTypes = None,
+        data: AsyncRequestData = b"",
+        json: typing.Any = None,
+    ):
+        super().__init__(
+            method=method, url=url, params=params, headers=headers, cookies=cookies
+        )
+
         if json is not None:
-            data = jsonlib.dumps(json).encode("utf-8")
+            self.is_streaming = False
+            self.content = self.encode_json(json)
             self.headers["Content-Type"] = "application/json"
-
-        if isinstance(data, bytes):
+        elif isinstance(data, bytes):
             self.is_streaming = False
             self.content = data
         elif isinstance(data, dict):
             self.is_streaming = False
-            self.content = urlencode(data, doseq=True).encode("utf-8")
+            self.content = self.urlencode_data(data)
             self.headers["Content-Type"] = "application/x-www-form-urlencoded"
         else:
+            assert hasattr(data, "__aiter__")
             self.is_streaming = True
             self.content_aiter = data
 
@@ -520,39 +585,55 @@ class Request:
         elif self.content:
             yield self.content
 
-    def prepare(self) -> None:
-        auto_headers = []  # type: typing.List[typing.Tuple[bytes, bytes]]
 
-        has_content_length = (
-            "content-length" in self.headers or "transfer-encoding" in self.headers
+class Request(BaseRequest):
+    def __init__(
+        self,
+        method: str,
+        url: typing.Union[str, URL],
+        *,
+        params: QueryParamTypes = None,
+        headers: HeaderTypes = None,
+        cookies: CookieTypes = None,
+        data: RequestData = b"",
+        json: typing.Any = None,
+    ):
+        super().__init__(
+            method=method, url=url, params=params, headers=headers, cookies=cookies
         )
-        has_accept_encoding = "accept-encoding" in self.headers
 
-        if not has_content_length:
-            if self.is_streaming:
-                auto_headers.append((b"transfer-encoding", b"chunked"))
-            elif self.content:
-                content_length = str(len(self.content)).encode()
-                auto_headers.append((b"content-length", content_length))
-        if not has_accept_encoding:
-            auto_headers.append((b"accept-encoding", ACCEPT_ENCODING.encode()))
+        if json is not None:
+            self.is_streaming = False
+            self.content = self.encode_json(json)
+            self.headers["Content-Type"] = "application/json"
+        elif isinstance(data, bytes):
+            self.is_streaming = False
+            self.content = data
+        elif isinstance(data, dict):
+            self.is_streaming = False
+            self.content = self.urlencode_data(data)
+            self.headers["Content-Type"] = "application/x-www-form-urlencoded"
+        else:
+            assert hasattr(data, "__iter__")
+            self.is_streaming = True
+            self.content_iter = data
 
-        for item in reversed(auto_headers):
-            self.headers.raw.insert(0, item)
+        self.prepare()
 
-    @property
-    def cookies(self) -> "Cookies":
-        if not hasattr(self, "_cookies"):
-            self._cookies = Cookies()
-        return self._cookies
+    def read(self) -> bytes:
+        if not hasattr(self, "content"):
+            self.content = b"".join([part for part in self.stream()])
+        return self.content
 
-    def __repr__(self) -> str:
-        class_name = self.__class__.__name__
-        url = str(self.url)
-        return f"<{class_name}({self.method!r}, {url!r})>"
+    def stream(self) -> typing.Iterator[bytes]:
+        if self.is_streaming:
+            for part in self.content_iter:
+                yield part
+        elif self.content:
+            yield self.content
 
 
-class Response:
+class BaseResponse:
     def __init__(
         self,
         status_code: int,
@@ -560,28 +641,16 @@ class Response:
         reason_phrase: str = None,
         protocol: str = None,
         headers: HeaderTypes = None,
-        content: ResponseContent = b"",
+        request: BaseRequest = None,
         on_close: typing.Callable = None,
-        request: Request = None,
-        history: typing.List["Response"] = None,
     ):
         self.status_code = StatusCode.enum_or_int(status_code)
         self.reason_phrase = StatusCode.get_reason_phrase(status_code)
         self.protocol = protocol
         self.headers = Headers(headers)
 
-        if isinstance(content, bytes):
-            self.is_closed = True
-            self.is_stream_consumed = True
-            self._raw_content = content
-        else:
-            self.is_closed = False
-            self.is_stream_consumed = False
-            self._raw_stream = content
-
-        self.on_close = on_close
         self.request = request
-        self.history = [] if history is None else list(history)
+        self.on_close = on_close
         self.next = None  # typing.Optional[typing.Callable]
 
     @property
@@ -597,7 +666,8 @@ class Response:
     def content(self) -> bytes:
         if not hasattr(self, "_content"):
             if hasattr(self, "_raw_content"):
-                content = self.decoder.decode(self._raw_content)
+                raw_content = getattr(self, "_raw_content")  # type: bytes
+                content = self.decoder.decode(raw_content)
                 content += self.decoder.flush()
                 self._content = content
             else:
@@ -682,6 +752,77 @@ class Response:
 
         return self._decoder
 
+    @property
+    def is_redirect(self) -> bool:
+        return StatusCode.is_redirect(self.status_code) and "location" in self.headers
+
+    def raise_for_status(self) -> None:
+        """
+        Raise the `HttpError` if one occurred.
+        """
+        message = (
+            "{0.status_code} {error_type}: {0.reason_phrase} for url: {0.url}\n"
+            "For more information check: https://httpstatuses.com/{0.status_code}"
+        )
+
+        if StatusCode.is_client_error(self.status_code):
+            message = message.format(self, error_type="Client Error")
+        elif StatusCode.is_server_error(self.status_code):
+            message = message.format(self, error_type="Server Error")
+        else:
+            message = ""
+
+        if message:
+            raise HttpError(message)
+
+    def json(self) -> typing.Any:
+        return jsonlib.loads(self.content.decode("utf-8"))
+
+    @property
+    def cookies(self) -> "Cookies":
+        if not hasattr(self, "_cookies"):
+            assert self.request is not None
+            self._cookies = Cookies()
+            self._cookies.extract_cookies(self)
+        return self._cookies
+
+    def __repr__(self) -> str:
+        return f"<Response({self.status_code}, {self.reason_phrase!r})>"
+
+
+class AsyncResponse(BaseResponse):
+    def __init__(
+        self,
+        status_code: int,
+        *,
+        reason_phrase: str = None,
+        protocol: str = None,
+        headers: HeaderTypes = None,
+        content: AsyncResponseContent = b"",
+        on_close: typing.Callable = None,
+        request: AsyncRequest = None,
+        history: typing.List["BaseResponse"] = None,
+    ):
+        super().__init__(
+            status_code=status_code,
+            reason_phrase=reason_phrase,
+            protocol=protocol,
+            headers=headers,
+            request=request,
+            on_close=on_close,
+        )
+
+        self.history = [] if history is None else list(history)
+
+        if isinstance(content, bytes):
+            self.is_closed = True
+            self.is_stream_consumed = True
+            self._raw_content = content
+        else:
+            self.is_closed = False
+            self.is_stream_consumed = False
+            self._raw_stream = content
+
     async def read(self) -> bytes:
         """
         Read and return the response content.
@@ -729,128 +870,86 @@ class Response:
             if self.on_close is not None:
                 await self.on_close()
 
-    @property
-    def is_redirect(self) -> bool:
-        return StatusCode.is_redirect(self.status_code) and "location" in self.headers
 
-    def raise_for_status(self) -> None:
-        """
-        Raise the `HttpError` if one occurred.
-        """
-        message = (
-            "{0.status_code} {error_type}: {0.reason_phrase} for url: {0.url}\n"
-            "For more information check: https://httpstatuses.com/{0.status_code}"
+class Response(BaseResponse):
+    def __init__(
+        self,
+        status_code: int,
+        *,
+        reason_phrase: str = None,
+        protocol: str = None,
+        headers: HeaderTypes = None,
+        content: ResponseContent = b"",
+        on_close: typing.Callable = None,
+        request: Request = None,
+        history: typing.List["BaseResponse"] = None,
+    ):
+        super().__init__(
+            status_code=status_code,
+            reason_phrase=reason_phrase,
+            protocol=protocol,
+            headers=headers,
+            request=request,
+            on_close=on_close,
         )
 
-        if StatusCode.is_client_error(self.status_code):
-            message = message.format(self, error_type="Client Error")
-        elif StatusCode.is_server_error(self.status_code):
-            message = message.format(self, error_type="Server Error")
-        else:
-            message = ""
-
-        if message:
-            raise HttpError(message)
-
-    def json(self) -> typing.Any:
-        return jsonlib.loads(self.content.decode("utf-8"))
-
-    @property
-    def cookies(self) -> "Cookies":
-        if not hasattr(self, "_cookies"):
-            assert self.request is not None
-            self._cookies = Cookies()
-            self._cookies.extract_cookies(self)
-        return self._cookies
-
-    def __repr__(self) -> str:
-        return f"<Response({self.status_code}, {self.reason_phrase!r})>"
-
-
-class SyncResponse:
-    """
-    A thread-synchronous response. This class proxies onto a `Response`
-    instance, providing standard synchronous interfaces where required.
-    """
-
-    def __init__(self, response: Response, loop: asyncio.AbstractEventLoop):
-        self._response = response
-        self._loop = loop
-
-    @property
-    def status_code(self) -> int:
-        return self._response.status_code
-
-    @property
-    def reason_phrase(self) -> str:
-        return self._response.reason_phrase
-
-    @property
-    def protocol(self) -> typing.Optional[str]:
-        return self._response.protocol
-
-    @property
-    def url(self) -> typing.Optional[URL]:
-        return self._response.url
-
-    @property
-    def request(self) -> typing.Optional[Request]:
-        return self._response.request
-
-    @property
-    def headers(self) -> Headers:
-        return self._response.headers
-
-    @property
-    def content(self) -> bytes:
-        return self._response.content
-
-    @property
-    def text(self) -> str:
-        return self._response.text
-
-    @property
-    def encoding(self) -> str:
-        return self._response.encoding
-
-    @property
-    def is_redirect(self) -> bool:
-        return self._response.is_redirect
-
-    def raise_for_status(self) -> None:
-        return self._response.raise_for_status()
+        self.history = [] if history is None else list(history)
 
-    def json(self) -> typing.Any:
-        return self._response.json()
+        if isinstance(content, bytes):
+            self.is_closed = True
+            self.is_stream_consumed = True
+            self._raw_content = content
+        else:
+            self.is_closed = False
+            self.is_stream_consumed = False
+            self._raw_stream = content
 
     def read(self) -> bytes:
-        return self._loop.run_until_complete(self._response.read())
+        """
+        Read and return the response content.
+        """
+        if not hasattr(self, "_content"):
+            self._content = b"".join([part for part in self.stream()])
+        return self._content
 
     def stream(self) -> typing.Iterator[bytes]:
-        inner = self._response.stream()
-        while True:
-            try:
-                yield self._loop.run_until_complete(inner.__anext__())
-            except StopAsyncIteration:
-                break
+        """
+        A byte-iterator over the decoded response content.
+        This allows us to handle gzip, deflate, and brotli encoded responses.
+        """
+        if hasattr(self, "_content"):
+            yield self._content
+        else:
+            for chunk in self.raw():
+                yield self.decoder.decode(chunk)
+            yield self.decoder.flush()
 
     def raw(self) -> typing.Iterator[bytes]:
-        inner = self._response.raw()
-        while True:
-            try:
-                yield self._loop.run_until_complete(inner.__anext__())
-            except StopAsyncIteration:
-                break
-
-    def close(self) -> None:
-        return self._loop.run_until_complete(self._response.close())
+        """
+        A byte-iterator over the raw response content.
+        """
+        if hasattr(self, "_raw_content"):
+            yield self._raw_content
+        else:
+            if self.is_stream_consumed:
+                raise StreamConsumed()
+            if self.is_closed:
+                raise ResponseClosed()
 
-    @property
-    def cookies(self) -> "Cookies":
-        return self._response.cookies
+            self.is_stream_consumed = True
+            for part in self._raw_stream:
+                yield part
+            self.close()
 
-    def __repr__(self) -> str:
-        return f"<Response({self.status_code}, {self.reason_phrase!r})>"
+    def close(self) -> None:
+        """
+        Close the response and release the connection.
+        Automatically called if the response body is read to completion.
+        """
+        if not self.is_closed:
+            self.is_closed = True
+            if self.on_close is not None:
+                self.on_close()
 
 
 class Cookies(MutableMapping):
@@ -871,7 +970,7 @@ class Cookies(MutableMapping):
         else:
             self.jar = cookies
 
-    def extract_cookies(self, response: Response) -> None:
+    def extract_cookies(self, response: BaseResponse) -> None:
         """
         Loads any cookies based on the response `Set-Cookie` headers.
         """
@@ -881,7 +980,7 @@ class Cookies(MutableMapping):
 
         self.jar.extract_cookies(urlib_response, urllib_request)  # type: ignore
 
-    def set_cookie_header(self, request: Request) -> None:
+    def set_cookie_header(self, request: BaseRequest) -> None:
         """
         Sets an appropriate 'Cookie:' HTTP header on the `Request`.
         """
@@ -1000,7 +1099,7 @@ class Cookies(MutableMapping):
         for use with `CookieJar` operations.
         """
 
-        def __init__(self, request: Request) -> None:
+        def __init__(self, request: BaseRequest) -> None:
             super().__init__(
                 url=str(request.url),
                 headers=dict(request.headers),
@@ -1018,7 +1117,7 @@ class Cookies(MutableMapping):
         for use with `CookieJar` operations.
         """
 
-        def __init__(self, response: Response):
+        def __init__(self, response: BaseResponse):
             self.response = response
 
         def info(self) -> email.message.Message:
index 1d2b97239c82d486c3418a4e9d5aa7a01f32ba42..17993383a9f904012c1fe4e05ba408bc3476e0f6 100644 (file)
@@ -4,27 +4,26 @@ import pytest
 
 from httpcore import (
     URL,
+    AsyncDispatcher,
+    AsyncRequest,
+    AsyncResponse,
     CertTypes,
     Client,
-    Dispatcher,
-    Request,
-    Response,
     TimeoutTypes,
     VerifyTypes,
 )
 
 
-class MockDispatch(Dispatcher):
+class MockDispatch(AsyncDispatcher):
     async def send(
         self,
-        request: Request,
-        stream: bool = False,
+        request: AsyncRequest,
         verify: VerifyTypes = None,
         cert: CertTypes = None,
         timeout: TimeoutTypes = None,
-    ) -> Response:
+    ) -> AsyncResponse:
         body = json.dumps({"auth": request.headers.get("Authorization")}).encode()
-        return Response(200, content=body, request=request)
+        return AsyncResponse(200, content=body, request=request)
 
 
 def test_basic_auth():
index a21f5c134f878f7a3ed22099bb1eec5b72ef49d2..5cbb380921d6dba42e2850a4a846c405516a67e3 100644 (file)
@@ -5,32 +5,31 @@ import pytest
 
 from httpcore import (
     URL,
+    AsyncDispatcher,
+    AsyncRequest,
+    AsyncResponse,
     CertTypes,
     Client,
     Cookies,
-    Dispatcher,
-    Request,
-    Response,
     TimeoutTypes,
     VerifyTypes,
 )
 
 
-class MockDispatch(Dispatcher):
+class MockDispatch(AsyncDispatcher):
     async def send(
         self,
-        request: Request,
-        stream: bool = False,
+        request: AsyncRequest,
         verify: VerifyTypes = None,
         cert: CertTypes = None,
         timeout: TimeoutTypes = None,
-    ) -> Response:
+    ) -> AsyncResponse:
         if request.url.path.startswith("/echo_cookies"):
             body = json.dumps({"cookies": request.headers.get("Cookie")}).encode()
-            return Response(200, content=body, request=request)
+            return AsyncResponse(200, content=body, request=request)
         elif request.url.path.startswith("/set_cookie"):
             headers = {"set-cookie": "example-name=example-value"}
-            return Response(200, headers=headers, request=request)
+            return AsyncResponse(200, headers=headers, request=request)
 
 
 def test_set_cookie():
index c3b384dc954053c191f2c4b273dadab0d7a24bf3..3f5168974a01d517026f9b2177b2f07a6d978d8f 100644 (file)
@@ -6,8 +6,10 @@ import pytest
 from httpcore import (
     URL,
     AsyncClient,
+    AsyncDispatcher,
+    AsyncRequest,
+    AsyncResponse,
     CertTypes,
-    Dispatcher,
     RedirectBodyUnavailable,
     RedirectLoop,
     Request,
@@ -19,37 +21,36 @@ from httpcore import (
 )
 
 
-class MockDispatch(Dispatcher):
+class MockDispatch(AsyncDispatcher):
     async def send(
         self,
-        request: Request,
-        stream: bool = False,
+        request: AsyncRequest,
         verify: VerifyTypes = None,
         cert: CertTypes = None,
         timeout: TimeoutTypes = None,
-    ) -> Response:
+    ) -> AsyncResponse:
         if request.url.path == "/redirect_301":
             status_code = codes.MOVED_PERMANENTLY
             headers = {"location": "https://example.org/"}
-            return Response(status_code, headers=headers, request=request)
+            return AsyncResponse(status_code, headers=headers, request=request)
 
         elif request.url.path == "/redirect_302":
             status_code = codes.FOUND
             headers = {"location": "https://example.org/"}
-            return Response(status_code, headers=headers, request=request)
+            return AsyncResponse(status_code, headers=headers, request=request)
 
         elif request.url.path == "/redirect_303":
             status_code = codes.SEE_OTHER
             headers = {"location": "https://example.org/"}
-            return Response(status_code, headers=headers, request=request)
+            return AsyncResponse(status_code, headers=headers, request=request)
 
         elif request.url.path == "/relative_redirect":
             headers = {"location": "/"}
-            return Response(codes.SEE_OTHER, headers=headers, request=request)
+            return AsyncResponse(codes.SEE_OTHER, headers=headers, request=request)
 
         elif request.url.path == "/no_scheme_redirect":
             headers = {"location": "//example.org/"}
-            return Response(codes.SEE_OTHER, headers=headers, request=request)
+            return AsyncResponse(codes.SEE_OTHER, headers=headers, request=request)
 
         elif request.url.path == "/multiple_redirects":
             params = parse_qs(request.url.query)
@@ -60,32 +61,34 @@ class MockDispatch(Dispatcher):
             if redirect_count:
                 location += "?count=" + str(redirect_count)
             headers = {"location": location} if count else {}
-            return Response(code, headers=headers, request=request)
+            return AsyncResponse(code, headers=headers, request=request)
 
         if request.url.path == "/redirect_loop":
             headers = {"location": "/redirect_loop"}
-            return Response(codes.SEE_OTHER, headers=headers, request=request)
+            return AsyncResponse(codes.SEE_OTHER, headers=headers, request=request)
 
         elif request.url.path == "/cross_domain":
             headers = {"location": "https://example.org/cross_domain_target"}
-            return Response(codes.SEE_OTHER, headers=headers, request=request)
+            return AsyncResponse(codes.SEE_OTHER, headers=headers, request=request)
 
         elif request.url.path == "/cross_domain_target":
             headers = dict(request.headers.items())
             content = json.dumps({"headers": headers}).encode()
-            return Response(codes.OK, content=content, request=request)
+            return AsyncResponse(codes.OK, content=content, request=request)
 
         elif request.url.path == "/redirect_body":
             await request.read()
             headers = {"location": "/redirect_body_target"}
-            return Response(codes.PERMANENT_REDIRECT, headers=headers, request=request)
+            return AsyncResponse(
+                codes.PERMANENT_REDIRECT, headers=headers, request=request
+            )
 
         elif request.url.path == "/redirect_body_target":
             content = await request.read()
             body = json.dumps({"body": content.decode()}).encode()
-            return Response(codes.OK, content=body, request=request)
+            return AsyncResponse(codes.OK, content=body, request=request)
 
-        return Response(codes.OK, content=b"Hello, world!", request=request)
+        return AsyncResponse(codes.OK, content=b"Hello, world!", request=request)
 
 
 @pytest.mark.asyncio
index bbe200ba949e6ddca3a71e94e9152c27c72f2d7a..b8049c70b835bdf126625cd2ed5ebe737d4210da 100644 (file)
@@ -10,10 +10,12 @@ async def test_keepalive_connections(server):
     """
     async with httpcore.ConnectionPool() as http:
         response = await http.request("GET", "http://127.0.0.1:8000/")
+        await response.read()
         assert len(http.active_connections) == 0
         assert len(http.keepalive_connections) == 1
 
         response = await http.request("GET", "http://127.0.0.1:8000/")
+        await response.read()
         assert len(http.active_connections) == 0
         assert len(http.keepalive_connections) == 1
 
@@ -25,10 +27,12 @@ async def test_differing_connection_keys(server):
     """
     async with httpcore.ConnectionPool() as http:
         response = await http.request("GET", "http://127.0.0.1:8000/")
+        await response.read()
         assert len(http.active_connections) == 0
         assert len(http.keepalive_connections) == 1
 
         response = await http.request("GET", "http://localhost:8000/")
+        await response.read()
         assert len(http.active_connections) == 0
         assert len(http.keepalive_connections) == 2
 
@@ -42,10 +46,12 @@ async def test_soft_limit(server):
 
     async with httpcore.ConnectionPool(pool_limits=pool_limits) as http:
         response = await http.request("GET", "http://127.0.0.1:8000/")
+        await response.read()
         assert len(http.active_connections) == 0
         assert len(http.keepalive_connections) == 1
 
         response = await http.request("GET", "http://localhost:8000/")
+        await response.read()
         assert len(http.active_connections) == 0
         assert len(http.keepalive_connections) == 1
 
@@ -56,7 +62,7 @@ async def test_streaming_response_holds_connection(server):
     A streaming request should hold the connection open until the response is read.
     """
     async with httpcore.ConnectionPool() as http:
-        response = await http.request("GET", "http://127.0.0.1:8000/", stream=True)
+        response = await http.request("GET", "http://127.0.0.1:8000/")
         assert len(http.active_connections) == 1
         assert len(http.keepalive_connections) == 0
 
@@ -72,11 +78,11 @@ async def test_multiple_concurrent_connections(server):
     Multiple conncurrent requests should open multiple conncurrent connections.
     """
     async with httpcore.ConnectionPool() as http:
-        response_a = await http.request("GET", "http://127.0.0.1:8000/", stream=True)
+        response_a = await http.request("GET", "http://127.0.0.1:8000/")
         assert len(http.active_connections) == 1
         assert len(http.keepalive_connections) == 0
 
-        response_b = await http.request("GET", "http://127.0.0.1:8000/", stream=True)
+        response_b = await http.request("GET", "http://127.0.0.1:8000/")
         assert len(http.active_connections) == 2
         assert len(http.keepalive_connections) == 0
 
@@ -97,6 +103,7 @@ async def test_close_connections(server):
     headers = [(b"connection", b"close")]
     async with httpcore.ConnectionPool() as http:
         response = await http.request("GET", "http://127.0.0.1:8000/", headers=headers)
+        await response.read()
         assert len(http.active_connections) == 0
         assert len(http.keepalive_connections) == 0
 
@@ -107,7 +114,7 @@ async def test_standard_response_close(server):
     A standard close should keep the connection open.
     """
     async with httpcore.ConnectionPool() as http:
-        response = await http.request("GET", "http://127.0.0.1:8000/", stream=True)
+        response = await http.request("GET", "http://127.0.0.1:8000/")
         await response.read()
         await response.close()
         assert len(http.active_connections) == 0
@@ -120,7 +127,7 @@ async def test_premature_response_close(server):
     A premature close should close the connection.
     """
     async with httpcore.ConnectionPool() as http:
-        response = await http.request("GET", "http://127.0.0.1:8000/", stream=True)
+        response = await http.request("GET", "http://127.0.0.1:8000/")
         await response.close()
         assert len(http.active_connections) == 0
         assert len(http.keepalive_connections) == 0
index ca401c7813316ef5c21574dbffbc11bb8d6e9a09..4b267f4fd97557988c49bdee11febde09be5a026 100644 (file)
@@ -7,6 +7,7 @@ from httpcore import HTTPConnection, Request, SSLConfig
 async def test_get(server):
     conn = HTTPConnection(origin="http://127.0.0.1:8000/")
     response = await conn.request("GET", "http://127.0.0.1:8000/")
+    await response.read()
     assert response.status_code == 200
     assert response.content == b"Hello, world!"
 
@@ -27,6 +28,7 @@ async def test_https_get_with_ssl_defaults(https_server):
     """
     conn = HTTPConnection(origin="https://127.0.0.1:8001/", verify=False)
     response = await conn.request("GET", "https://127.0.0.1:8001/")
+    await response.read()
     assert response.status_code == 200
     assert response.content == b"Hello, world!"
 
@@ -38,5 +40,6 @@ async def test_https_get_with_sll_overrides(https_server):
     """
     conn = HTTPConnection(origin="https://127.0.0.1:8001/")
     response = await conn.request("GET", "https://127.0.0.1:8001/", verify=False)
+    await response.read()
     assert response.status_code == 200
     assert response.content == b"Hello, world!"
diff --git a/tests/dispatch/test_threaded.py b/tests/dispatch/test_threaded.py
new file mode 100644 (file)
index 0000000..d177dbb
--- /dev/null
@@ -0,0 +1,100 @@
+import json
+
+import pytest
+
+from httpcore import (
+    CertTypes,
+    Client,
+    Dispatcher,
+    Request,
+    Response,
+    TimeoutTypes,
+    VerifyTypes,
+)
+
+
+def streaming_body():
+    for part in [b"Hello", b", ", b"world!"]:
+        yield part
+
+
+class MockDispatch(Dispatcher):
+    def send(
+        self,
+        request: Request,
+        verify: VerifyTypes = None,
+        cert: CertTypes = None,
+        timeout: TimeoutTypes = None,
+    ) -> Response:
+        if request.url.path == "/streaming_response":
+            return Response(200, content=streaming_body(), request=request)
+        elif request.url.path == "/echo_request_body":
+            content = request.read()
+            return Response(200, content=content, request=request)
+        elif request.url.path == "/echo_request_body_streaming":
+            content = b"".join([part for part in request.stream()])
+            return Response(200, content=content, request=request)
+        else:
+            body = json.dumps({"hello": "world"}).encode()
+            return Response(200, content=body, request=request)
+
+
+def test_threaded_dispatch():
+    """
+    Use a syncronous 'Dispatcher' class with the client.
+    Calls to the dispatcher will end up running within a thread pool.
+    """
+    url = "https://example.org/"
+    with Client(dispatch=MockDispatch()) as client:
+        response = client.get(url)
+
+    assert response.status_code == 200
+    assert response.json() == {"hello": "world"}
+
+
+def test_threaded_streaming_response():
+    url = "https://example.org/streaming_response"
+    with Client(dispatch=MockDispatch()) as client:
+        response = client.get(url)
+
+    assert response.status_code == 200
+    assert response.text == "Hello, world!"
+
+
+def test_threaded_streaming_request():
+    url = "https://example.org/echo_request_body"
+    with Client(dispatch=MockDispatch()) as client:
+        response = client.post(url, data=streaming_body())
+
+    assert response.status_code == 200
+    assert response.text == "Hello, world!"
+
+
+def test_threaded_request_body():
+    url = "https://example.org/echo_request_body"
+    with Client(dispatch=MockDispatch()) as client:
+        response = client.post(url, data=b"Hello, world!")
+
+    assert response.status_code == 200
+    assert response.text == "Hello, world!"
+
+
+def test_threaded_request_body_streaming():
+    url = "https://example.org/echo_request_body_streaming"
+    with Client(dispatch=MockDispatch()) as client:
+        response = client.post(url, data=b"Hello, world!")
+
+    assert response.status_code == 200
+    assert response.text == "Hello, world!"
+
+
+def test_dispatch_class():
+    """
+    Use a syncronous 'Dispatcher' class directly.
+    """
+    url = "https://example.org/"
+    with MockDispatch() as dispatcher:
+        response = dispatcher.request("GET", url)
+
+    assert response.status_code == 200
+    assert response.json() == {"hello": "world"}
index d0d521a4680ff03149158681b33a0073e3a9189f..79cbba36e802b806c42e059146582cff362e301e 100644 (file)
@@ -10,87 +10,62 @@ def test_request_repr():
 
 def test_no_content():
     request = httpcore.Request("GET", "http://example.org")
-    request.prepare()
-    assert request.headers == httpcore.Headers(
-        [(b"accept-encoding", b"deflate, gzip, br")]
-    )
+    assert "Content-Length" not in request.headers
 
 
 def test_content_length_header():
     request = httpcore.Request("POST", "http://example.org", data=b"test 123")
-    request.prepare()
-    assert request.headers == httpcore.Headers(
-        [
-            (b"content-length", b"8"),
-            (b"accept-encoding", b"deflate, gzip, br"),
-        ]
-    )
+    assert request.headers["Content-Length"] == "8"
 
 
 def test_url_encoded_data():
-    request = httpcore.Request("POST", "http://example.org", data={"test": "123"})
-    request.prepare()
-    assert request.headers == httpcore.Headers(
-        [
-            (b"content-length", b"8"),
-            (b"accept-encoding", b"deflate, gzip, br"),
-            (b"content-type", b"application/x-www-form-urlencoded"),
-        ]
-    )
-    assert request.content == b"test=123"
+    for RequestClass in (httpcore.Request, httpcore.AsyncRequest):
+        request = RequestClass("POST", "http://example.org", data={"test": "123"})
+        assert request.headers["Content-Type"] == "application/x-www-form-urlencoded"
+        assert request.content == b"test=123"
+
+
+def test_json_encoded_data():
+    for RequestClass in (httpcore.Request, httpcore.AsyncRequest):
+        request = RequestClass("POST", "http://example.org", json={"test": 123})
+        assert request.headers["Content-Type"] == "application/json"
+        assert request.content == b'{"test": 123}'
 
 
 def test_transfer_encoding_header():
-    async def streaming_body(data):
+    def streaming_body(data):
         yield data  # pragma: nocover
 
     data = streaming_body(b"test 123")
 
     request = httpcore.Request("POST", "http://example.org", data=data)
-    request.prepare()
-    assert request.headers == httpcore.Headers(
-        [
-            (b"transfer-encoding", b"chunked"),
-            (b"accept-encoding", b"deflate, gzip, br"),
-        ]
-    )
+    assert "Content-Length" not in request.headers
+    assert request.headers["Transfer-Encoding"] == "chunked"
 
 
 def test_override_host_header():
-    headers = [(b"host", b"1.2.3.4:80")]
+    headers = {"host": "1.2.3.4:80"}
 
     request = httpcore.Request("GET", "http://example.org", headers=headers)
-    request.prepare()
-    assert request.headers == httpcore.Headers(
-        [(b"accept-encoding", b"deflate, gzip, br"), (b"host", b"1.2.3.4:80")]
-    )
+    assert request.headers["Host"] == "1.2.3.4:80"
 
 
 def test_override_accept_encoding_header():
-    headers = [(b"accept-encoding", b"identity")]
+    headers = {"Accept-Encoding": "identity"}
 
     request = httpcore.Request("GET", "http://example.org", headers=headers)
-    request.prepare()
-    assert request.headers == httpcore.Headers(
-        [(b"accept-encoding", b"identity")]
-    )
+    assert request.headers["Accept-Encoding"] == "identity"
 
 
 def test_override_content_length_header():
-    async def streaming_body(data):
+    def streaming_body(data):
         yield data  # pragma: nocover
 
     data = streaming_body(b"test 123")
-    headers = [(b"content-length", b"8")]
+    headers = {"Content-Length": "8"}
 
     request = httpcore.Request("POST", "http://example.org", data=data, headers=headers)
-    request.prepare()
-    assert request.headers == httpcore.Headers(
-        [
-            (b"accept-encoding", b"deflate, gzip, br"),
-            (b"content-length", b"8"),
-        ]
-    )
+    assert request.headers["Content-Length"] == "8"
 
 
 def test_url():
index 8ecd37ab5c4aee4be5475f5331a14fa3f62a5ee2..f2d080ffc81cf82a4b0c679d7b78a1639e4028df 100644 (file)
@@ -3,7 +3,12 @@ import pytest
 import httpcore
 
 
-async def streaming_body():
+def streaming_body():
+    yield b"Hello, "
+    yield b"world!"
+
+
+async def async_streaming_body():
     yield b"Hello, "
     yield b"world!"
 
@@ -105,8 +110,7 @@ def test_response_force_encoding():
     assert response.encoding == "iso-8859-1"
 
 
-@pytest.mark.asyncio
-async def test_read_response():
+def test_read_response():
     response = httpcore.Response(200, content=b"Hello, world!")
 
     assert response.status_code == 200
@@ -114,37 +118,56 @@ async def test_read_response():
     assert response.encoding == "ascii"
     assert response.is_closed
 
-    content = await response.read()
+    content = response.read()
 
     assert content == b"Hello, world!"
     assert response.content == b"Hello, world!"
     assert response.is_closed
 
 
-@pytest.mark.asyncio
-async def test_raw_interface():
+def test_raw_interface():
     response = httpcore.Response(200, content=b"Hello, world!")
 
     raw = b""
-    async for part in response.raw():
+    for part in response.raw():
         raw += part
     assert raw == b"Hello, world!"
 
 
-@pytest.mark.asyncio
-async def test_stream_interface():
+def test_stream_interface():
     response = httpcore.Response(200, content=b"Hello, world!")
 
     content = b""
-    async for part in response.stream():
+    for part in response.stream():
         content += part
     assert content == b"Hello, world!"
 
 
 @pytest.mark.asyncio
-async def test_stream_interface_after_read():
+async def test_async_stream_interface():
+    response = httpcore.AsyncResponse(200, content=b"Hello, world!")
+
+    content = b""
+    async for part in response.stream():
+        content += part
+    assert content == b"Hello, world!"
+
+
+def test_stream_interface_after_read():
     response = httpcore.Response(200, content=b"Hello, world!")
 
+    response.read()
+
+    content = b""
+    for part in response.stream():
+        content += part
+    assert content == b"Hello, world!"
+
+
+@pytest.mark.asyncio
+async def test_async_stream_interface_after_read():
+    response = httpcore.AsyncResponse(200, content=b"Hello, world!")
+
     await response.read()
 
     content = b""
@@ -153,14 +176,13 @@ async def test_stream_interface_after_read():
     assert content == b"Hello, world!"
 
 
-@pytest.mark.asyncio
-async def test_streaming_response():
+def test_streaming_response():
     response = httpcore.Response(200, content=streaming_body())
 
     assert response.status_code == 200
     assert not response.is_closed
 
-    content = await response.read()
+    content = response.read()
 
     assert content == b"Hello, world!"
     assert response.content == b"Hello, world!"
@@ -168,9 +190,34 @@ async def test_streaming_response():
 
 
 @pytest.mark.asyncio
-async def test_cannot_read_after_stream_consumed():
+async def test_async_streaming_response():
+    response = httpcore.AsyncResponse(200, content=async_streaming_body())
+
+    assert response.status_code == 200
+    assert not response.is_closed
+
+    content = await response.read()
+
+    assert content == b"Hello, world!"
+    assert response.content == b"Hello, world!"
+    assert response.is_closed
+
+
+def test_cannot_read_after_stream_consumed():
     response = httpcore.Response(200, content=streaming_body())
 
+    content = b""
+    for part in response.stream():
+        content += part
+
+    with pytest.raises(httpcore.StreamConsumed):
+        response.read()
+
+
+@pytest.mark.asyncio
+async def test_async_cannot_read_after_stream_consumed():
+    response = httpcore.AsyncResponse(200, content=async_streaming_body())
+
     content = b""
     async for part in response.stream():
         content += part
@@ -179,10 +226,19 @@ async def test_cannot_read_after_stream_consumed():
         await response.read()
 
 
-@pytest.mark.asyncio
-async def test_cannot_read_after_response_closed():
+def test_cannot_read_after_response_closed():
     response = httpcore.Response(200, content=streaming_body())
 
+    response.close()
+
+    with pytest.raises(httpcore.ResponseClosed):
+        response.read()
+
+
+@pytest.mark.asyncio
+async def test_async_cannot_read_after_response_closed():
+    response = httpcore.AsyncResponse(200, content=async_streaming_body())
+
     await response.close()
 
     with pytest.raises(httpcore.ResponseClosed):
index 6a62359c16753b2da3c70bcc98568aa4faba97a3..1247a41602a658815805e5e5234034499fdfdd99 100644 (file)
@@ -38,6 +38,18 @@ def test_post(server):
     assert response.reason_phrase == "OK"
 
 
+@threadpool
+def test_post_byte_iterator(server):
+    def data():
+        yield b"Hello"
+        yield b", "
+        yield b"world!"
+
+    response = httpcore.post("http://127.0.0.1:8000/", data=data())
+    assert response.status_code == 200
+    assert response.reason_phrase == "OK"
+
+
 @threadpool
 def test_options(server):
     response = httpcore.options("http://127.0.0.1:8000/")
index 20273eec26efd09fcb352d49d7cd73b29806b64d..ac795ca91e1ad5c9465a84171e6d0fc1d6d44a1d 100644 (file)
@@ -64,19 +64,18 @@ def test_multi_with_identity():
     assert response.content == body
 
 
-@pytest.mark.asyncio
-async def test_streaming():
+def test_streaming():
     body = b"test 123"
     compressor = zlib.compressobj(9, zlib.DEFLATED, zlib.MAX_WBITS | 16)
 
-    async def compress(body):
+    def compress(body):
         yield compressor.compress(body)
         yield compressor.flush()
 
     headers = [(b"Content-Encoding", b"gzip")]
     response = httpcore.Response(200, headers=headers, content=compress(body))
     assert not hasattr(response, "body")
-    assert await response.read() == body
+    assert response.read() == body
 
 
 @pytest.mark.parametrize("header_value", (b"deflate", b"gzip", b"br"))