TooManyRedirects,
WriteTimeout,
)
-from .interfaces import BaseReader, BaseWriter, ConcurrencyBackend, Dispatcher, Protocol
-from .models import URL, Cookies, Headers, Origin, QueryParams, Request, Response
+from .interfaces import (
+ AsyncDispatcher,
+ BaseReader,
+ BaseWriter,
+ ConcurrencyBackend,
+ Dispatcher,
+ Protocol,
+)
+from .models import (
+ URL,
+ AsyncRequest,
+ AsyncResponse,
+ Cookies,
+ Headers,
+ Origin,
+ QueryParams,
+ Request,
+ Response,
+)
from .status_codes import StatusCode, codes
__version__ = "0.4.0"
HeaderTypes,
QueryParamTypes,
RequestData,
- SyncResponse,
+ Response,
URLTypes,
)
cert: CertTypes = None,
verify: VerifyTypes = True,
stream: bool = False,
-) -> SyncResponse:
+) -> Response:
with Client() as client:
return client.request(
method=method,
cert: CertTypes = None,
verify: VerifyTypes = True,
timeout: TimeoutTypes = None,
-) -> SyncResponse:
+) -> Response:
return request(
"GET",
url,
cert: CertTypes = None,
verify: VerifyTypes = True,
timeout: TimeoutTypes = None,
-) -> SyncResponse:
+) -> Response:
return request(
"OPTIONS",
url,
cert: CertTypes = None,
verify: VerifyTypes = True,
timeout: TimeoutTypes = None,
-) -> SyncResponse:
+) -> Response:
return request(
"HEAD",
url,
cert: CertTypes = None,
verify: VerifyTypes = True,
timeout: TimeoutTypes = None,
-) -> SyncResponse:
+) -> Response:
return request(
"POST",
url,
cert: CertTypes = None,
verify: VerifyTypes = True,
timeout: TimeoutTypes = None,
-) -> SyncResponse:
+) -> Response:
return request(
"PUT",
url,
cert: CertTypes = None,
verify: VerifyTypes = True,
timeout: TimeoutTypes = None,
-) -> SyncResponse:
+) -> Response:
return request(
"PATCH",
url,
cert: CertTypes = None,
verify: VerifyTypes = True,
timeout: TimeoutTypes = None,
-) -> SyncResponse:
+) -> Response:
return request(
"DELETE",
url,
import typing
from base64 import b64encode
-from .models import Request
+from .models import AsyncRequest
class AuthBase:
Base class that all auth implementations derive from.
"""
- def __call__(self, request: Request) -> Request:
+ def __call__(self, request: AsyncRequest) -> AsyncRequest:
raise NotImplementedError("Auth hooks must be callable.") # pragma: nocover
self.username = username
self.password = password
- def __call__(self, request: Request) -> Request:
+ def __call__(self, request: AsyncRequest) -> AsyncRequest:
request.headers["Authorization"] = self.build_auth_header()
return request
-import asyncio
import typing
from types import TracebackType
from .auth import HTTPBasicAuth
+from .concurrency import AsyncioBackend
from .config import (
DEFAULT_MAX_REDIRECTS,
DEFAULT_POOL_LIMITS,
VerifyTypes,
)
from .dispatch.connection_pool import ConnectionPool
+from .dispatch.threaded import ThreadedDispatcher
from .exceptions import RedirectBodyUnavailable, RedirectLoop, TooManyRedirects
-from .interfaces import ConcurrencyBackend, Dispatcher
+from .interfaces import AsyncDispatcher, ConcurrencyBackend, Dispatcher
from .models import (
URL,
+ AsyncRequest,
+ AsyncRequestData,
+ AsyncResponse,
+ AsyncResponseContent,
AuthTypes,
Cookies,
CookieTypes,
Request,
RequestData,
Response,
- SyncResponse,
+ ResponseContent,
URLTypes,
)
from .status_codes import codes
-class AsyncClient:
+class BaseClient:
def __init__(
self,
auth: AuthTypes = None,
timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG,
pool_limits: PoolLimits = DEFAULT_POOL_LIMITS,
max_redirects: int = DEFAULT_MAX_REDIRECTS,
- dispatch: Dispatcher = None,
+ dispatch: typing.Union[AsyncDispatcher, Dispatcher] = None,
backend: ConcurrencyBackend = None,
):
+ if backend is None:
+ backend = AsyncioBackend()
+
if dispatch is None:
- dispatch = ConnectionPool(
+ async_dispatch = ConnectionPool(
verify=verify,
cert=cert,
timeout=timeout,
pool_limits=pool_limits,
backend=backend,
- )
+ ) # type: AsyncDispatcher
+ elif isinstance(dispatch, Dispatcher):
+ async_dispatch = ThreadedDispatcher(dispatch, backend)
+ else:
+ async_dispatch = dispatch
self.auth = auth
self.cookies = Cookies(cookies)
self.max_redirects = max_redirects
- self.dispatch = dispatch
+ self.dispatch = async_dispatch
+ self.concurrency_backend = backend
+
+ def merge_cookies(
+ self, cookies: CookieTypes = None
+ ) -> typing.Optional[CookieTypes]:
+ if cookies or self.cookies:
+ merged_cookies = Cookies(self.cookies)
+ merged_cookies.update(cookies)
+ return merged_cookies
+ return cookies
+
+ async def send(
+ self,
+ request: AsyncRequest,
+ *,
+ stream: bool = False,
+ auth: AuthTypes = None,
+ allow_redirects: bool = True,
+ verify: VerifyTypes = None,
+ cert: CertTypes = None,
+ timeout: TimeoutTypes = None,
+ ) -> AsyncResponse:
+ if auth is None:
+ auth = self.auth
+
+ url = request.url
+ if auth is None and (url.username or url.password):
+ auth = HTTPBasicAuth(username=url.username, password=url.password)
+
+ if auth is not None:
+ if isinstance(auth, tuple):
+ auth = HTTPBasicAuth(username=auth[0], password=auth[1])
+ request = auth(request)
+
+ response = await self.send_handling_redirects(
+ request,
+ verify=verify,
+ cert=cert,
+ timeout=timeout,
+ allow_redirects=allow_redirects,
+ )
+
+ if not stream:
+ try:
+ await response.read()
+ finally:
+ await response.close()
+
+ return response
+
+ async def send_handling_redirects(
+ self,
+ request: AsyncRequest,
+ *,
+ cert: CertTypes = None,
+ verify: VerifyTypes = None,
+ timeout: TimeoutTypes = None,
+ allow_redirects: bool = True,
+ history: typing.List[AsyncResponse] = None,
+ ) -> AsyncResponse:
+ if history is None:
+ history = []
+
+ while True:
+ # We perform these checks here, so that calls to `response.next()`
+ # will raise redirect errors if appropriate.
+ if len(history) > self.max_redirects:
+ raise TooManyRedirects()
+ if request.url in [response.url for response in history]:
+ raise RedirectLoop()
+
+ response = await self.dispatch.send(
+ request, verify=verify, cert=cert, timeout=timeout
+ )
+ assert isinstance(response, AsyncResponse)
+ response.history = list(history)
+ self.cookies.extract_cookies(response)
+ history = [response] + history
+ if not response.is_redirect:
+ break
+
+ if allow_redirects:
+ request = self.build_redirect_request(request, response)
+ else:
+
+ async def send_next() -> AsyncResponse:
+ nonlocal request, response, verify, cert, allow_redirects, timeout, history
+ request = self.build_redirect_request(request, response)
+ response = await self.send_handling_redirects(
+ request,
+ allow_redirects=allow_redirects,
+ verify=verify,
+ cert=cert,
+ timeout=timeout,
+ history=history,
+ )
+ return response
+
+ response.next = send_next # type: ignore
+ break
+
+ return response
+
+ def build_redirect_request(
+ self, request: AsyncRequest, response: AsyncResponse
+ ) -> AsyncRequest:
+ method = self.redirect_method(request, response)
+ url = self.redirect_url(request, response)
+ headers = self.redirect_headers(request, url)
+ content = self.redirect_content(request, method)
+ cookies = self.merge_cookies(request.cookies)
+ return AsyncRequest(
+ method=method, url=url, headers=headers, data=content, cookies=cookies
+ )
+
+ def redirect_method(self, request: AsyncRequest, response: AsyncResponse) -> str:
+ """
+ When being redirected we may want to change the method of the request
+ based on certain specs or browser behavior.
+ """
+ method = request.method
+
+ # https://tools.ietf.org/html/rfc7231#section-6.4.4
+ if response.status_code == codes.SEE_OTHER and method != "HEAD":
+ method = "GET"
+
+ # Do what the browsers do, despite standards...
+ # Turn 302s into GETs.
+ if response.status_code == codes.FOUND and method != "HEAD":
+ method = "GET"
+
+ # If a POST is responded to with a 301, turn it into a GET.
+ # This bizarre behaviour is explained in 'requests' issue 1704.
+ if response.status_code == codes.MOVED_PERMANENTLY and method == "POST":
+ method = "GET"
+
+ return method
+
+ def redirect_url(self, request: AsyncRequest, response: AsyncResponse) -> URL:
+ """
+ Return the URL for the redirect to follow.
+ """
+ location = response.headers["Location"]
+
+ url = URL(location, allow_relative=True)
+
+ # Facilitate relative 'Location' headers, as allowed by RFC 7231.
+ # (e.g. '/path/to/resource' instead of 'http://domain.tld/path/to/resource')
+ if url.is_relative_url:
+ url = url.resolve_with(request.url)
+
+ # Attach previous fragment if needed (RFC 7231 7.1.2)
+ if request.url.fragment and not url.fragment:
+ url = url.copy_with(fragment=request.url.fragment)
+
+ return url
+
+ def redirect_headers(self, request: AsyncRequest, url: URL) -> Headers:
+ """
+ Strip Authorization headers when responses are redirected away from
+ the origin.
+ """
+ headers = Headers(request.headers)
+ if url.origin != request.url.origin:
+ del headers["Authorization"]
+ return headers
+
+ def redirect_content(self, request: AsyncRequest, method: str) -> bytes:
+ """
+ Return the body that should be used for the redirect request.
+ """
+ if method != request.method and method == "GET":
+ return b""
+ if request.is_streaming:
+ raise RedirectBodyUnavailable()
+ return request.content
+
+class AsyncClient(BaseClient):
async def get(
self,
url: URLTypes,
cert: CertTypes = None,
verify: VerifyTypes = None,
timeout: TimeoutTypes = None,
- ) -> Response:
+ ) -> AsyncResponse:
return await self.request(
"GET",
url,
cert: CertTypes = None,
verify: VerifyTypes = None,
timeout: TimeoutTypes = None,
- ) -> Response:
+ ) -> AsyncResponse:
return await self.request(
"OPTIONS",
url,
cert: CertTypes = None,
verify: VerifyTypes = None,
timeout: TimeoutTypes = None,
- ) -> Response:
+ ) -> AsyncResponse:
return await self.request(
"HEAD",
url,
self,
url: URLTypes,
*,
- data: RequestData = b"",
+ data: AsyncRequestData = b"",
json: typing.Any = None,
params: QueryParamTypes = None,
headers: HeaderTypes = None,
cert: CertTypes = None,
verify: VerifyTypes = None,
timeout: TimeoutTypes = None,
- ) -> Response:
+ ) -> AsyncResponse:
return await self.request(
"POST",
url,
self,
url: URLTypes,
*,
- data: RequestData = b"",
+ data: AsyncRequestData = b"",
json: typing.Any = None,
params: QueryParamTypes = None,
headers: HeaderTypes = None,
cert: CertTypes = None,
verify: VerifyTypes = None,
timeout: TimeoutTypes = None,
- ) -> Response:
+ ) -> AsyncResponse:
return await self.request(
"PUT",
url,
self,
url: URLTypes,
*,
- data: RequestData = b"",
+ data: AsyncRequestData = b"",
json: typing.Any = None,
params: QueryParamTypes = None,
headers: HeaderTypes = None,
cert: CertTypes = None,
verify: VerifyTypes = None,
timeout: TimeoutTypes = None,
- ) -> Response:
+ ) -> AsyncResponse:
return await self.request(
"PATCH",
url,
self,
url: URLTypes,
*,
- data: RequestData = b"",
+ data: AsyncRequestData = b"",
json: typing.Any = None,
params: QueryParamTypes = None,
headers: HeaderTypes = None,
cert: CertTypes = None,
verify: VerifyTypes = None,
timeout: TimeoutTypes = None,
- ) -> Response:
+ ) -> AsyncResponse:
return await self.request(
"DELETE",
url,
method: str,
url: URLTypes,
*,
- data: RequestData = b"",
+ data: AsyncRequestData = b"",
json: typing.Any = None,
params: QueryParamTypes = None,
headers: HeaderTypes = None,
cert: CertTypes = None,
verify: VerifyTypes = None,
timeout: TimeoutTypes = None,
- ) -> Response:
- request = Request(
+ ) -> AsyncResponse:
+ request = AsyncRequest(
method,
url,
data=data,
)
return response
- def merge_cookies(
- self, cookies: CookieTypes = None
- ) -> typing.Optional[CookieTypes]:
- if cookies or self.cookies:
- merged_cookies = Cookies(self.cookies)
- merged_cookies.update(cookies)
- return merged_cookies
- return cookies
-
- async def send(
- self,
- request: Request,
- *,
- stream: bool = False,
- auth: AuthTypes = None,
- allow_redirects: bool = True,
- verify: VerifyTypes = None,
- cert: CertTypes = None,
- timeout: TimeoutTypes = None,
- ) -> Response:
- if auth is None:
- auth = self.auth
-
- url = request.url
- if auth is None and (url.username or url.password):
- auth = HTTPBasicAuth(username=url.username, password=url.password)
-
- if auth is not None:
- if isinstance(auth, tuple):
- auth = HTTPBasicAuth(username=auth[0], password=auth[1])
- request = auth(request)
-
- response = await self.send_handling_redirects(
- request,
- stream=stream,
- verify=verify,
- cert=cert,
- timeout=timeout,
- allow_redirects=allow_redirects,
- )
- return response
-
- async def send_handling_redirects(
- self,
- request: Request,
- *,
- stream: bool = False,
- cert: CertTypes = None,
- verify: VerifyTypes = None,
- timeout: TimeoutTypes = None,
- allow_redirects: bool = True,
- history: typing.List[Response] = None,
- ) -> Response:
- if history is None:
- history = []
-
- while True:
- # We perform these checks here, so that calls to `response.next()`
- # will raise redirect errors if appropriate.
- if len(history) > self.max_redirects:
- raise TooManyRedirects()
- if request.url in [response.url for response in history]:
- raise RedirectLoop()
-
- response = await self.dispatch.send(
- request, stream=stream, verify=verify, cert=cert, timeout=timeout
- )
- response.history = list(history)
- self.cookies.extract_cookies(response)
- history = [response] + history
- if not response.is_redirect:
- break
-
- if allow_redirects:
- request = self.build_redirect_request(request, response)
- else:
-
- async def send_next() -> Response:
- nonlocal request, response, verify, cert, allow_redirects, timeout, history
- request = self.build_redirect_request(request, response)
- response = await self.send_handling_redirects(
- request,
- stream=stream,
- allow_redirects=allow_redirects,
- verify=verify,
- cert=cert,
- timeout=timeout,
- history=history,
- )
- return response
-
- response.next = send_next # type: ignore
- break
-
- return response
-
- def build_redirect_request(self, request: Request, response: Response) -> Request:
- method = self.redirect_method(request, response)
- url = self.redirect_url(request, response)
- headers = self.redirect_headers(request, url)
- content = self.redirect_content(request, method)
- cookies = self.merge_cookies(request.cookies)
- return Request(
- method=method, url=url, headers=headers, data=content, cookies=cookies
- )
-
- def redirect_method(self, request: Request, response: Response) -> str:
- """
- When being redirected we may want to change the method of the request
- based on certain specs or browser behavior.
- """
- method = request.method
-
- # https://tools.ietf.org/html/rfc7231#section-6.4.4
- if response.status_code == codes.SEE_OTHER and method != "HEAD":
- method = "GET"
-
- # Do what the browsers do, despite standards...
- # Turn 302s into GETs.
- if response.status_code == codes.FOUND and method != "HEAD":
- method = "GET"
-
- # If a POST is responded to with a 301, turn it into a GET.
- # This bizarre behaviour is explained in 'requests' issue 1704.
- if response.status_code == codes.MOVED_PERMANENTLY and method == "POST":
- method = "GET"
-
- return method
-
- def redirect_url(self, request: Request, response: Response) -> URL:
- """
- Return the URL for the redirect to follow.
- """
- location = response.headers["Location"]
-
- url = URL(location, allow_relative=True)
-
- # Facilitate relative 'Location' headers, as allowed by RFC 7231.
- # (e.g. '/path/to/resource' instead of 'http://domain.tld/path/to/resource')
- if url.is_relative_url:
- url = url.resolve_with(request.url)
-
- # Attach previous fragment if needed (RFC 7231 7.1.2)
- if request.url.fragment and not url.fragment:
- url = url.copy_with(fragment=request.url.fragment)
-
- return url
-
- def redirect_headers(self, request: Request, url: URL) -> Headers:
- """
- Strip Authorization headers when responses are redirected away from
- the origin.
- """
- headers = Headers(request.headers)
- if url.origin != request.url.origin:
- del headers["Authorization"]
- return headers
-
- def redirect_content(self, request: Request, method: str) -> bytes:
- """
- Return the body that should be used for the redirect request.
- """
- if method != request.method and method == "GET":
- return b""
- if request.is_streaming:
- raise RedirectBodyUnavailable()
- return request.content
-
async def close(self) -> None:
await self.dispatch.close()
await self.close()
-class Client:
- def __init__(
- self,
- auth: AuthTypes = None,
- cert: CertTypes = None,
- verify: VerifyTypes = True,
- timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG,
- pool_limits: PoolLimits = DEFAULT_POOL_LIMITS,
- max_redirects: int = DEFAULT_MAX_REDIRECTS,
- dispatch: Dispatcher = None,
- backend: ConcurrencyBackend = None,
- ) -> None:
- self._client = AsyncClient(
- auth=auth,
- verify=verify,
- cert=cert,
- timeout=timeout,
- pool_limits=pool_limits,
- max_redirects=max_redirects,
- dispatch=dispatch,
- backend=backend,
- )
- self._loop = asyncio.new_event_loop()
+class Client(BaseClient):
+ def _async_request_data(self, data: RequestData) -> AsyncRequestData:
+ """
+ If the request data is an bytes iterator then return an async bytes
+ iterator onto the request data.
+ """
+ if isinstance(data, (bytes, dict)):
+ return data
+
+ # Coerce an iterator into an async iterator, with each item in the
+ # iteration running as a thread-pooled operation.
+ assert hasattr(data, "__iter__")
+ return self.concurrency_backend.iterate_in_threadpool(data)
- @property
- def cookies(self) -> Cookies:
- return self._client.cookies
+ def _sync_data(self, data: AsyncResponseContent) -> ResponseContent:
+ if isinstance(data, bytes):
+ return data
+
+ # Coerce an async iterator into an iterator, with each item in the
+ # iteration run within the event loop.
+ assert hasattr(data, "__aiter__")
+ return self.concurrency_backend.iterate(data)
def request(
self,
cert: CertTypes = None,
verify: VerifyTypes = None,
timeout: TimeoutTypes = None,
- ) -> SyncResponse:
- request = Request(
+ ) -> Response:
+ request = AsyncRequest(
method,
url,
- data=data,
+ data=self._async_request_data(data),
json=json,
params=params,
headers=headers,
- cookies=self._client.merge_cookies(cookies),
+ cookies=self.merge_cookies(cookies),
)
- response = self.send(
- request,
- stream=stream,
+ concurrency_backend = self.concurrency_backend
+
+ coroutine = self.send
+ args = [request]
+ kwargs = dict(
+ stream=True,
auth=auth,
allow_redirects=allow_redirects,
verify=verify,
cert=cert,
timeout=timeout,
)
+ async_response = concurrency_backend.run(coroutine, *args, **kwargs)
+
+ content = getattr(
+ async_response, "_raw_content", getattr(async_response, "_raw_stream", None)
+ )
+
+ sync_content = self._sync_data(content)
+
+ def sync_on_close() -> None:
+ nonlocal concurrency_backend, async_response
+ concurrency_backend.run(async_response.on_close)
+
+ response = Response(
+ status_code=async_response.status_code,
+ reason_phrase=async_response.reason_phrase,
+ protocol=async_response.protocol,
+ headers=async_response.headers,
+ content=sync_content,
+ on_close=sync_on_close,
+ request=async_response.request,
+ history=async_response.history,
+ )
+ if not stream:
+ try:
+ response.read()
+ finally:
+ response.close()
return response
def get(
cert: CertTypes = None,
verify: VerifyTypes = None,
timeout: TimeoutTypes = None,
- ) -> SyncResponse:
+ ) -> Response:
return self.request(
"GET",
url,
cert: CertTypes = None,
verify: VerifyTypes = None,
timeout: TimeoutTypes = None,
- ) -> SyncResponse:
+ ) -> Response:
return self.request(
"OPTIONS",
url,
cert: CertTypes = None,
verify: VerifyTypes = None,
timeout: TimeoutTypes = None,
- ) -> SyncResponse:
+ ) -> Response:
return self.request(
"HEAD",
url,
cert: CertTypes = None,
verify: VerifyTypes = None,
timeout: TimeoutTypes = None,
- ) -> SyncResponse:
+ ) -> Response:
return self.request(
"POST",
url,
cert: CertTypes = None,
verify: VerifyTypes = None,
timeout: TimeoutTypes = None,
- ) -> SyncResponse:
+ ) -> Response:
return self.request(
"PUT",
url,
cert: CertTypes = None,
verify: VerifyTypes = None,
timeout: TimeoutTypes = None,
- ) -> SyncResponse:
+ ) -> Response:
return self.request(
"PATCH",
url,
cert: CertTypes = None,
verify: VerifyTypes = None,
timeout: TimeoutTypes = None,
- ) -> SyncResponse:
+ ) -> Response:
return self.request(
"DELETE",
url,
timeout=timeout,
)
- def send(
- self,
- request: Request,
- *,
- stream: bool = False,
- auth: AuthTypes = None,
- allow_redirects: bool = True,
- verify: VerifyTypes = None,
- cert: CertTypes = None,
- timeout: TimeoutTypes = None,
- ) -> SyncResponse:
- response = self._loop.run_until_complete(
- self._client.send(
- request,
- stream=stream,
- auth=auth,
- allow_redirects=allow_redirects,
- verify=verify,
- cert=cert,
- timeout=timeout,
- )
- )
- return SyncResponse(response, self._loop)
-
def close(self) -> None:
- self._loop.run_until_complete(self._client.close())
+ coroutine = self.dispatch.close
+ self.concurrency_backend.run(coroutine)
def __enter__(self) -> "Client":
return self
based, and less strictly `asyncio`-specific.
"""
import asyncio
+import functools
import ssl
import typing
ssl_monkey_patch()
SSL_MONKEY_PATCH_APPLIED = True
+ @property
+ def loop(self) -> asyncio.AbstractEventLoop:
+ if not hasattr(self, "_loop"):
+ try:
+ self._loop = asyncio.get_event_loop()
+ except RuntimeError:
+ self._loop = asyncio.new_event_loop()
+ return self._loop
+
async def connect(
self,
hostname: str,
return (reader, writer, protocol)
+ async def run_in_threadpool(
+ self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any
+ ) -> typing.Any:
+ if kwargs:
+ # loop.run_in_executor doesn't accept 'kwargs', so bind them in here
+ func = functools.partial(func, **kwargs)
+ return await self.loop.run_in_executor(None, func, *args)
+
+ def run(
+ self, coroutine: typing.Callable, *args: typing.Any, **kwargs: typing.Any
+ ) -> typing.Any:
+ loop = self.loop
+ if loop.is_running():
+ self._loop = asyncio.new_event_loop()
+ try:
+ return self.loop.run_until_complete(coroutine(*args, **kwargs))
+ finally:
+ self._loop = loop
+
def get_semaphore(self, limits: PoolLimits) -> BasePoolSemaphore:
return PoolSemaphore(limits)
VerifyTypes,
)
from ..exceptions import ConnectTimeout
-from ..interfaces import ConcurrencyBackend, Dispatcher, Protocol
-from ..models import Origin, Request, Response
+from ..interfaces import AsyncDispatcher, ConcurrencyBackend, Protocol
+from ..models import AsyncRequest, AsyncResponse, Origin
from .http2 import HTTP2Connection
from .http11 import HTTP11Connection
ReleaseCallback = typing.Callable[["HTTPConnection"], typing.Awaitable[None]]
-class HTTPConnection(Dispatcher):
+class HTTPConnection(AsyncDispatcher):
def __init__(
self,
origin: typing.Union[str, Origin],
async def send(
self,
- request: Request,
- stream: bool = False,
+ request: AsyncRequest,
verify: VerifyTypes = None,
cert: CertTypes = None,
timeout: TimeoutTypes = None,
- ) -> Response:
+ ) -> AsyncResponse:
if self.h11_connection is None and self.h2_connection is None:
await self.connect(verify=verify, cert=cert, timeout=timeout)
if self.h2_connection is not None:
- response = await self.h2_connection.send(
- request, stream=stream, timeout=timeout
- )
+ response = await self.h2_connection.send(request, timeout=timeout)
else:
assert self.h11_connection is not None
- response = await self.h11_connection.send(
- request, stream=stream, timeout=timeout
- )
+ response = await self.h11_connection.send(request, timeout=timeout)
return response
)
from ..decoders import ACCEPT_ENCODING
from ..exceptions import PoolTimeout
-from ..interfaces import ConcurrencyBackend, Dispatcher
-from ..models import Origin, Request, Response
+from ..interfaces import AsyncDispatcher, ConcurrencyBackend
+from ..models import AsyncRequest, AsyncResponse, Origin
from .connection import HTTPConnection
CONNECTIONS_DICT = typing.Dict[Origin, typing.List[HTTPConnection]]
return len(self.all)
-class ConnectionPool(Dispatcher):
+class ConnectionPool(AsyncDispatcher):
def __init__(
self,
*,
async def send(
self,
- request: Request,
- stream: bool = False,
+ request: AsyncRequest,
verify: VerifyTypes = None,
cert: CertTypes = None,
timeout: TimeoutTypes = None,
- ) -> Response:
+ ) -> AsyncResponse:
connection = await self.acquire_connection(request.url.origin)
try:
response = await connection.send(
- request, stream=stream, verify=verify, cert=cert, timeout=timeout
+ request, verify=verify, cert=cert, timeout=timeout
)
except BaseException as exc:
self.active_connections.remove(connection)
from ..config import DEFAULT_TIMEOUT_CONFIG, TimeoutConfig, TimeoutTypes
from ..exceptions import ConnectTimeout, ReadTimeout
-from ..interfaces import BaseReader, BaseWriter, Dispatcher
-from ..models import Request, Response
+from ..interfaces import BaseReader, BaseWriter
+from ..models import AsyncRequest, AsyncResponse
H11Event = typing.Union[
h11.Request,
self.h11_state = h11.Connection(our_role=h11.CLIENT)
async def send(
- self, request: Request, stream: bool = False, timeout: TimeoutTypes = None
- ) -> Response:
+ self, request: AsyncRequest, timeout: TimeoutTypes = None
+ ) -> AsyncResponse:
timeout = None if timeout is None else TimeoutConfig(timeout)
# Â Start sending the request.
method = request.method.encode("ascii")
target = request.url.full_path.encode("ascii")
headers = request.headers.raw
- if 'Host' not in request.headers:
+ if "Host" not in request.headers:
host = request.url.authority.encode("ascii")
headers = [(b"host", host)] + headers
event = h11.Request(method=method, target=target, headers=headers)
headers = event.headers
content = self._body_iter(timeout)
- response = Response(
+ return AsyncResponse(
status_code=status_code,
reason_phrase=reason_phrase,
protocol="HTTP/1.1",
request=request,
)
- if not stream:
- try:
- await response.read()
- finally:
- await response.close()
-
- return response
-
async def close(self) -> None:
event = h11.ConnectionClosed()
self.h11_state.send(event)
from ..config import DEFAULT_TIMEOUT_CONFIG, TimeoutConfig, TimeoutTypes
from ..exceptions import ConnectTimeout, ReadTimeout
-from ..interfaces import BaseReader, BaseWriter, Dispatcher
-from ..models import Request, Response
+from ..interfaces import BaseReader, BaseWriter
+from ..models import AsyncRequest, AsyncResponse
class HTTP2Connection:
self.initialized = False
async def send(
- self, request: Request, stream: bool = False, timeout: TimeoutTypes = None
- ) -> Response:
+ self, request: AsyncRequest, timeout: TimeoutTypes = None
+ ) -> AsyncResponse:
timeout = None if timeout is None else TimeoutConfig(timeout)
# Â Start sending the request.
content = self.body_iter(stream_id, timeout)
on_close = functools.partial(self.response_closed, stream_id=stream_id)
- response = Response(
+ return AsyncResponse(
status_code=status_code,
protocol="HTTP/2",
headers=headers,
request=request,
)
- if not stream:
- try:
- await response.read()
- finally:
- await response.close()
-
- return response
-
async def close(self) -> None:
await self.writer.close()
self.initialized = True
async def send_headers(
- self, request: Request, timeout: TimeoutConfig = None
+ self, request: AsyncRequest, timeout: TimeoutConfig = None
) -> int:
stream_id = self.h2_state.get_next_available_stream_id()
headers = [
--- /dev/null
+from ..config import CertTypes, TimeoutTypes, VerifyTypes
+from ..interfaces import AsyncDispatcher, ConcurrencyBackend, Dispatcher
+from ..models import (
+ AsyncRequest,
+ AsyncRequestData,
+ AsyncResponse,
+ AsyncResponseContent,
+ Request,
+ RequestData,
+ Response,
+ ResponseContent,
+)
+
+
+class ThreadedDispatcher(AsyncDispatcher):
+ """
+ The ThreadedDispatcher class is used to mediate between the Client
+ (which always uses async under the hood), and a synchronous `Dispatch`
+ class.
+ """
+
+ def __init__(self, dispatch: Dispatcher, backend: ConcurrencyBackend) -> None:
+ self.sync_dispatcher = dispatch
+ self.backend = backend
+
+ async def send(
+ self,
+ request: AsyncRequest,
+ verify: VerifyTypes = None,
+ cert: CertTypes = None,
+ timeout: TimeoutTypes = None,
+ ) -> AsyncResponse:
+ concurrency_backend = self.backend
+
+ data = getattr(request, "content", getattr(request, "content_aiter", None))
+ sync_data = self._sync_request_data(data)
+
+ sync_request = Request(
+ method=request.method,
+ url=request.url,
+ headers=request.headers,
+ data=sync_data,
+ )
+
+ func = self.sync_dispatcher.send
+ kwargs = {
+ "request": sync_request,
+ "verify": verify,
+ "cert": cert,
+ "timeout": timeout,
+ }
+ sync_response = await self.backend.run_in_threadpool(func, **kwargs)
+ assert isinstance(sync_response, Response)
+
+ content = getattr(
+ sync_response, "_raw_content", getattr(sync_response, "_raw_stream", None)
+ )
+
+ async_content = self._async_response_content(content)
+
+ async def async_on_close() -> None:
+ nonlocal concurrency_backend, sync_response
+ await concurrency_backend.run_in_threadpool(sync_response.close)
+
+ return AsyncResponse(
+ status_code=sync_response.status_code,
+ reason_phrase=sync_response.reason_phrase,
+ protocol=sync_response.protocol,
+ headers=sync_response.headers,
+ content=async_content,
+ on_close=async_on_close,
+ request=request,
+ history=sync_response.history,
+ )
+
+ async def close(self) -> None:
+ """
+ The `.close()` method runs the `Dispatcher.close()` within a threadpool,
+ so as not to block the async event loop.
+ """
+ func = self.sync_dispatcher.close
+ await self.backend.run_in_threadpool(func)
+
+ def _async_response_content(self, content: ResponseContent) -> AsyncResponseContent:
+ if isinstance(content, bytes):
+ return content
+
+ # Coerce an async iterator into an iterator, with each item in the
+ # iteration run within the event loop.
+ assert hasattr(content, "__iter__")
+ return self.backend.iterate_in_threadpool(content)
+
+ def _sync_request_data(self, data: AsyncRequestData) -> RequestData:
+ if isinstance(data, bytes):
+ return data
+
+ return self.backend.iterate(data)
from .config import CertTypes, PoolLimits, TimeoutConfig, TimeoutTypes, VerifyTypes
from .models import (
URL,
+ AsyncRequest,
+ AsyncRequestData,
+ AsyncResponse,
Headers,
HeaderTypes,
QueryParamTypes,
HTTP_2 = "HTTP/2"
-class Dispatcher:
+class AsyncDispatcher:
"""
- Base class for dispatcher classes, that handle sending the request.
+ Base class for async dispatcher classes, that handle sending the request.
Stubs out the interface, as well as providing a `.request()` convienence
implementation, to make it easy to use or test stand-alone dispatchers,
"""
async def request(
+ self,
+ method: str,
+ url: URLTypes,
+ *,
+ data: AsyncRequestData = b"",
+ params: QueryParamTypes = None,
+ headers: HeaderTypes = None,
+ verify: VerifyTypes = None,
+ cert: CertTypes = None,
+ timeout: TimeoutTypes = None
+ ) -> AsyncResponse:
+ request = AsyncRequest(method, url, data=data, params=params, headers=headers)
+ return await self.send(request, verify=verify, cert=cert, timeout=timeout)
+
+ async def send(
+ self,
+ request: AsyncRequest,
+ verify: VerifyTypes = None,
+ cert: CertTypes = None,
+ timeout: TimeoutTypes = None,
+ ) -> AsyncResponse:
+ raise NotImplementedError() # pragma: nocover
+
+ async def close(self) -> None:
+ pass # pragma: nocover
+
+ async def __aenter__(self) -> "AsyncDispatcher":
+ return self
+
+ async def __aexit__(
+ self,
+ exc_type: typing.Type[BaseException] = None,
+ exc_value: BaseException = None,
+ traceback: TracebackType = None,
+ ) -> None:
+ await self.close()
+
+
+class Dispatcher:
+ """
+ Base class for syncronous dispatcher classes, that handle sending the request.
+
+ Stubs out the interface, as well as providing a `.request()` convienence
+ implementation, to make it easy to use or test stand-alone dispatchers,
+ without requiring a complete `Client` instance.
+ """
+
+ def request(
self,
method: str,
url: URLTypes,
data: RequestData = b"",
params: QueryParamTypes = None,
headers: HeaderTypes = None,
- stream: bool = False,
verify: VerifyTypes = None,
cert: CertTypes = None,
timeout: TimeoutTypes = None
) -> Response:
request = Request(method, url, data=data, params=params, headers=headers)
- response = await self.send(
- request, stream=stream, verify=verify, cert=cert, timeout=timeout
- )
- return response
+ return self.send(request, verify=verify, cert=cert, timeout=timeout)
- async def send(
+ def send(
self,
request: Request,
- stream: bool = False,
verify: VerifyTypes = None,
cert: CertTypes = None,
timeout: TimeoutTypes = None,
) -> Response:
raise NotImplementedError() # pragma: nocover
- async def close(self) -> None:
+ def close(self) -> None:
pass # pragma: nocover
- async def __aenter__(self) -> "Dispatcher":
+ def __enter__(self) -> "Dispatcher":
return self
- async def __aexit__(
+ def __exit__(
self,
exc_type: typing.Type[BaseException] = None,
exc_value: BaseException = None,
traceback: TracebackType = None,
) -> None:
- await self.close()
+ self.close()
class BaseReader:
def get_semaphore(self, limits: PoolLimits) -> BasePoolSemaphore:
raise NotImplementedError() # pragma: no cover
+
+ async def run_in_threadpool(
+ self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any
+ ) -> typing.Any:
+ raise NotImplementedError() # pragma: no cover
+
+ async def iterate_in_threadpool(self, iterator): # type: ignore
+ class IterationComplete(Exception):
+ pass
+
+ def next_wrapper(iterator): # type: ignore
+ try:
+ return next(iterator)
+ except StopIteration:
+ raise IterationComplete()
+
+ while True:
+ try:
+ yield await self.run_in_threadpool(next_wrapper, iterator)
+ except IterationComplete:
+ break
+
+ def run(
+ self, coroutine: typing.Callable, *args: typing.Any, **kwargs: typing.Any
+ ) -> typing.Any:
+ raise NotImplementedError() # pragma: no cover
+
+ def iterate(self, async_iterator): # type: ignore
+ while True:
+ try:
+ yield self.run(async_iterator.__anext__)
+ except StopAsyncIteration:
+ break
-import asyncio
import cgi
import email.message
import json as jsonlib
AuthTypes = typing.Union[
typing.Tuple[typing.Union[str, bytes], typing.Union[str, bytes]],
- typing.Callable[["Request"], "Request"],
+ typing.Callable[["AsyncRequest"], "AsyncRequest"],
]
-RequestData = typing.Union[dict, bytes, typing.AsyncIterator[bytes]]
+AsyncRequestData = typing.Union[dict, bytes, typing.AsyncIterator[bytes]]
-ResponseContent = typing.Union[bytes, typing.AsyncIterator[bytes]]
+RequestData = typing.Union[dict, bytes, typing.Iterator[bytes]]
+
+AsyncResponseContent = typing.Union[bytes, typing.AsyncIterator[bytes]]
+
+ResponseContent = typing.Union[bytes, typing.Iterator[bytes]]
class URL:
return f"{class_name}({as_list!r}{encoding_str})"
-class Request:
+class BaseRequest:
def __init__(
self,
method: str,
url: typing.Union[str, URL],
*,
- data: RequestData = b"",
- json: typing.Any = None,
params: QueryParamTypes = None,
headers: HeaderTypes = None,
cookies: CookieTypes = None,
self._cookies = Cookies(cookies)
self._cookies.set_cookie_header(self)
+ def encode_json(self, json: typing.Any) -> bytes:
+ return jsonlib.dumps(json).encode("utf-8")
+
+ def urlencode_data(self, data: dict) -> bytes:
+ return urlencode(data, doseq=True).encode("utf-8")
+
+ def prepare(self) -> None:
+ content = getattr(self, "content", None) # type: bytes
+ is_streaming = getattr(self, "is_streaming", False)
+
+ auto_headers = [] # type: typing.List[typing.Tuple[bytes, bytes]]
+
+ has_user_agent = "user-agent" in self.headers
+ has_accept = "accept" in self.headers
+ has_content_length = (
+ "content-length" in self.headers or "transfer-encoding" in self.headers
+ )
+ has_accept_encoding = "accept-encoding" in self.headers
+
+ if not has_user_agent:
+ auto_headers.append((b"user-agent", b"httpcore"))
+ if not has_accept:
+ auto_headers.append((b"accept", b"*/*"))
+ if not has_content_length:
+ if is_streaming:
+ auto_headers.append((b"transfer-encoding", b"chunked"))
+ elif content:
+ content_length = str(len(content)).encode()
+ auto_headers.append((b"content-length", content_length))
+ if not has_accept_encoding:
+ auto_headers.append((b"accept-encoding", ACCEPT_ENCODING.encode()))
+
+ for item in reversed(auto_headers):
+ self.headers.raw.insert(0, item)
+
+ @property
+ def cookies(self) -> "Cookies":
+ if not hasattr(self, "_cookies"):
+ self._cookies = Cookies()
+ return self._cookies
+
+ def __repr__(self) -> str:
+ class_name = self.__class__.__name__
+ url = str(self.url)
+ return f"<{class_name}({self.method!r}, {url!r})>"
+
+
+class AsyncRequest(BaseRequest):
+ def __init__(
+ self,
+ method: str,
+ url: typing.Union[str, URL],
+ *,
+ params: QueryParamTypes = None,
+ headers: HeaderTypes = None,
+ cookies: CookieTypes = None,
+ data: AsyncRequestData = b"",
+ json: typing.Any = None,
+ ):
+ super().__init__(
+ method=method, url=url, params=params, headers=headers, cookies=cookies
+ )
+
if json is not None:
- data = jsonlib.dumps(json).encode("utf-8")
+ self.is_streaming = False
+ self.content = self.encode_json(json)
self.headers["Content-Type"] = "application/json"
-
- if isinstance(data, bytes):
+ elif isinstance(data, bytes):
self.is_streaming = False
self.content = data
elif isinstance(data, dict):
self.is_streaming = False
- self.content = urlencode(data, doseq=True).encode("utf-8")
+ self.content = self.urlencode_data(data)
self.headers["Content-Type"] = "application/x-www-form-urlencoded"
else:
+ assert hasattr(data, "__aiter__")
self.is_streaming = True
self.content_aiter = data
elif self.content:
yield self.content
- def prepare(self) -> None:
- auto_headers = [] # type: typing.List[typing.Tuple[bytes, bytes]]
- has_content_length = (
- "content-length" in self.headers or "transfer-encoding" in self.headers
+class Request(BaseRequest):
+ def __init__(
+ self,
+ method: str,
+ url: typing.Union[str, URL],
+ *,
+ params: QueryParamTypes = None,
+ headers: HeaderTypes = None,
+ cookies: CookieTypes = None,
+ data: RequestData = b"",
+ json: typing.Any = None,
+ ):
+ super().__init__(
+ method=method, url=url, params=params, headers=headers, cookies=cookies
)
- has_accept_encoding = "accept-encoding" in self.headers
- if not has_content_length:
- if self.is_streaming:
- auto_headers.append((b"transfer-encoding", b"chunked"))
- elif self.content:
- content_length = str(len(self.content)).encode()
- auto_headers.append((b"content-length", content_length))
- if not has_accept_encoding:
- auto_headers.append((b"accept-encoding", ACCEPT_ENCODING.encode()))
+ if json is not None:
+ self.is_streaming = False
+ self.content = self.encode_json(json)
+ self.headers["Content-Type"] = "application/json"
+ elif isinstance(data, bytes):
+ self.is_streaming = False
+ self.content = data
+ elif isinstance(data, dict):
+ self.is_streaming = False
+ self.content = self.urlencode_data(data)
+ self.headers["Content-Type"] = "application/x-www-form-urlencoded"
+ else:
+ assert hasattr(data, "__iter__")
+ self.is_streaming = True
+ self.content_iter = data
- for item in reversed(auto_headers):
- self.headers.raw.insert(0, item)
+ self.prepare()
- @property
- def cookies(self) -> "Cookies":
- if not hasattr(self, "_cookies"):
- self._cookies = Cookies()
- return self._cookies
+ def read(self) -> bytes:
+ if not hasattr(self, "content"):
+ self.content = b"".join([part for part in self.stream()])
+ return self.content
- def __repr__(self) -> str:
- class_name = self.__class__.__name__
- url = str(self.url)
- return f"<{class_name}({self.method!r}, {url!r})>"
+ def stream(self) -> typing.Iterator[bytes]:
+ if self.is_streaming:
+ for part in self.content_iter:
+ yield part
+ elif self.content:
+ yield self.content
-class Response:
+class BaseResponse:
def __init__(
self,
status_code: int,
reason_phrase: str = None,
protocol: str = None,
headers: HeaderTypes = None,
- content: ResponseContent = b"",
+ request: BaseRequest = None,
on_close: typing.Callable = None,
- request: Request = None,
- history: typing.List["Response"] = None,
):
self.status_code = StatusCode.enum_or_int(status_code)
self.reason_phrase = StatusCode.get_reason_phrase(status_code)
self.protocol = protocol
self.headers = Headers(headers)
- if isinstance(content, bytes):
- self.is_closed = True
- self.is_stream_consumed = True
- self._raw_content = content
- else:
- self.is_closed = False
- self.is_stream_consumed = False
- self._raw_stream = content
-
- self.on_close = on_close
self.request = request
- self.history = [] if history is None else list(history)
+ self.on_close = on_close
self.next = None # typing.Optional[typing.Callable]
@property
def content(self) -> bytes:
if not hasattr(self, "_content"):
if hasattr(self, "_raw_content"):
- content = self.decoder.decode(self._raw_content)
+ raw_content = getattr(self, "_raw_content") # type: bytes
+ content = self.decoder.decode(raw_content)
content += self.decoder.flush()
self._content = content
else:
return self._decoder
+ @property
+ def is_redirect(self) -> bool:
+ return StatusCode.is_redirect(self.status_code) and "location" in self.headers
+
+ def raise_for_status(self) -> None:
+ """
+ Raise the `HttpError` if one occurred.
+ """
+ message = (
+ "{0.status_code} {error_type}: {0.reason_phrase} for url: {0.url}\n"
+ "For more information check: https://httpstatuses.com/{0.status_code}"
+ )
+
+ if StatusCode.is_client_error(self.status_code):
+ message = message.format(self, error_type="Client Error")
+ elif StatusCode.is_server_error(self.status_code):
+ message = message.format(self, error_type="Server Error")
+ else:
+ message = ""
+
+ if message:
+ raise HttpError(message)
+
+ def json(self) -> typing.Any:
+ return jsonlib.loads(self.content.decode("utf-8"))
+
+ @property
+ def cookies(self) -> "Cookies":
+ if not hasattr(self, "_cookies"):
+ assert self.request is not None
+ self._cookies = Cookies()
+ self._cookies.extract_cookies(self)
+ return self._cookies
+
+ def __repr__(self) -> str:
+ return f"<Response({self.status_code}, {self.reason_phrase!r})>"
+
+
+class AsyncResponse(BaseResponse):
+ def __init__(
+ self,
+ status_code: int,
+ *,
+ reason_phrase: str = None,
+ protocol: str = None,
+ headers: HeaderTypes = None,
+ content: AsyncResponseContent = b"",
+ on_close: typing.Callable = None,
+ request: AsyncRequest = None,
+ history: typing.List["BaseResponse"] = None,
+ ):
+ super().__init__(
+ status_code=status_code,
+ reason_phrase=reason_phrase,
+ protocol=protocol,
+ headers=headers,
+ request=request,
+ on_close=on_close,
+ )
+
+ self.history = [] if history is None else list(history)
+
+ if isinstance(content, bytes):
+ self.is_closed = True
+ self.is_stream_consumed = True
+ self._raw_content = content
+ else:
+ self.is_closed = False
+ self.is_stream_consumed = False
+ self._raw_stream = content
+
async def read(self) -> bytes:
"""
Read and return the response content.
if self.on_close is not None:
await self.on_close()
- @property
- def is_redirect(self) -> bool:
- return StatusCode.is_redirect(self.status_code) and "location" in self.headers
- def raise_for_status(self) -> None:
- """
- Raise the `HttpError` if one occurred.
- """
- message = (
- "{0.status_code} {error_type}: {0.reason_phrase} for url: {0.url}\n"
- "For more information check: https://httpstatuses.com/{0.status_code}"
+class Response(BaseResponse):
+ def __init__(
+ self,
+ status_code: int,
+ *,
+ reason_phrase: str = None,
+ protocol: str = None,
+ headers: HeaderTypes = None,
+ content: ResponseContent = b"",
+ on_close: typing.Callable = None,
+ request: Request = None,
+ history: typing.List["BaseResponse"] = None,
+ ):
+ super().__init__(
+ status_code=status_code,
+ reason_phrase=reason_phrase,
+ protocol=protocol,
+ headers=headers,
+ request=request,
+ on_close=on_close,
)
- if StatusCode.is_client_error(self.status_code):
- message = message.format(self, error_type="Client Error")
- elif StatusCode.is_server_error(self.status_code):
- message = message.format(self, error_type="Server Error")
- else:
- message = ""
-
- if message:
- raise HttpError(message)
-
- def json(self) -> typing.Any:
- return jsonlib.loads(self.content.decode("utf-8"))
-
- @property
- def cookies(self) -> "Cookies":
- if not hasattr(self, "_cookies"):
- assert self.request is not None
- self._cookies = Cookies()
- self._cookies.extract_cookies(self)
- return self._cookies
-
- def __repr__(self) -> str:
- return f"<Response({self.status_code}, {self.reason_phrase!r})>"
-
-
-class SyncResponse:
- """
- A thread-synchronous response. This class proxies onto a `Response`
- instance, providing standard synchronous interfaces where required.
- """
-
- def __init__(self, response: Response, loop: asyncio.AbstractEventLoop):
- self._response = response
- self._loop = loop
-
- @property
- def status_code(self) -> int:
- return self._response.status_code
-
- @property
- def reason_phrase(self) -> str:
- return self._response.reason_phrase
-
- @property
- def protocol(self) -> typing.Optional[str]:
- return self._response.protocol
-
- @property
- def url(self) -> typing.Optional[URL]:
- return self._response.url
-
- @property
- def request(self) -> typing.Optional[Request]:
- return self._response.request
-
- @property
- def headers(self) -> Headers:
- return self._response.headers
-
- @property
- def content(self) -> bytes:
- return self._response.content
-
- @property
- def text(self) -> str:
- return self._response.text
-
- @property
- def encoding(self) -> str:
- return self._response.encoding
-
- @property
- def is_redirect(self) -> bool:
- return self._response.is_redirect
-
- def raise_for_status(self) -> None:
- return self._response.raise_for_status()
+ self.history = [] if history is None else list(history)
- def json(self) -> typing.Any:
- return self._response.json()
+ if isinstance(content, bytes):
+ self.is_closed = True
+ self.is_stream_consumed = True
+ self._raw_content = content
+ else:
+ self.is_closed = False
+ self.is_stream_consumed = False
+ self._raw_stream = content
def read(self) -> bytes:
- return self._loop.run_until_complete(self._response.read())
+ """
+ Read and return the response content.
+ """
+ if not hasattr(self, "_content"):
+ self._content = b"".join([part for part in self.stream()])
+ return self._content
def stream(self) -> typing.Iterator[bytes]:
- inner = self._response.stream()
- while True:
- try:
- yield self._loop.run_until_complete(inner.__anext__())
- except StopAsyncIteration:
- break
+ """
+ A byte-iterator over the decoded response content.
+ This allows us to handle gzip, deflate, and brotli encoded responses.
+ """
+ if hasattr(self, "_content"):
+ yield self._content
+ else:
+ for chunk in self.raw():
+ yield self.decoder.decode(chunk)
+ yield self.decoder.flush()
def raw(self) -> typing.Iterator[bytes]:
- inner = self._response.raw()
- while True:
- try:
- yield self._loop.run_until_complete(inner.__anext__())
- except StopAsyncIteration:
- break
-
- def close(self) -> None:
- return self._loop.run_until_complete(self._response.close())
+ """
+ A byte-iterator over the raw response content.
+ """
+ if hasattr(self, "_raw_content"):
+ yield self._raw_content
+ else:
+ if self.is_stream_consumed:
+ raise StreamConsumed()
+ if self.is_closed:
+ raise ResponseClosed()
- @property
- def cookies(self) -> "Cookies":
- return self._response.cookies
+ self.is_stream_consumed = True
+ for part in self._raw_stream:
+ yield part
+ self.close()
- def __repr__(self) -> str:
- return f"<Response({self.status_code}, {self.reason_phrase!r})>"
+ def close(self) -> None:
+ """
+ Close the response and release the connection.
+ Automatically called if the response body is read to completion.
+ """
+ if not self.is_closed:
+ self.is_closed = True
+ if self.on_close is not None:
+ self.on_close()
class Cookies(MutableMapping):
else:
self.jar = cookies
- def extract_cookies(self, response: Response) -> None:
+ def extract_cookies(self, response: BaseResponse) -> None:
"""
Loads any cookies based on the response `Set-Cookie` headers.
"""
self.jar.extract_cookies(urlib_response, urllib_request) # type: ignore
- def set_cookie_header(self, request: Request) -> None:
+ def set_cookie_header(self, request: BaseRequest) -> None:
"""
Sets an appropriate 'Cookie:' HTTP header on the `Request`.
"""
for use with `CookieJar` operations.
"""
- def __init__(self, request: Request) -> None:
+ def __init__(self, request: BaseRequest) -> None:
super().__init__(
url=str(request.url),
headers=dict(request.headers),
for use with `CookieJar` operations.
"""
- def __init__(self, response: Response):
+ def __init__(self, response: BaseResponse):
self.response = response
def info(self) -> email.message.Message:
from httpcore import (
URL,
+ AsyncDispatcher,
+ AsyncRequest,
+ AsyncResponse,
CertTypes,
Client,
- Dispatcher,
- Request,
- Response,
TimeoutTypes,
VerifyTypes,
)
-class MockDispatch(Dispatcher):
+class MockDispatch(AsyncDispatcher):
async def send(
self,
- request: Request,
- stream: bool = False,
+ request: AsyncRequest,
verify: VerifyTypes = None,
cert: CertTypes = None,
timeout: TimeoutTypes = None,
- ) -> Response:
+ ) -> AsyncResponse:
body = json.dumps({"auth": request.headers.get("Authorization")}).encode()
- return Response(200, content=body, request=request)
+ return AsyncResponse(200, content=body, request=request)
def test_basic_auth():
from httpcore import (
URL,
+ AsyncDispatcher,
+ AsyncRequest,
+ AsyncResponse,
CertTypes,
Client,
Cookies,
- Dispatcher,
- Request,
- Response,
TimeoutTypes,
VerifyTypes,
)
-class MockDispatch(Dispatcher):
+class MockDispatch(AsyncDispatcher):
async def send(
self,
- request: Request,
- stream: bool = False,
+ request: AsyncRequest,
verify: VerifyTypes = None,
cert: CertTypes = None,
timeout: TimeoutTypes = None,
- ) -> Response:
+ ) -> AsyncResponse:
if request.url.path.startswith("/echo_cookies"):
body = json.dumps({"cookies": request.headers.get("Cookie")}).encode()
- return Response(200, content=body, request=request)
+ return AsyncResponse(200, content=body, request=request)
elif request.url.path.startswith("/set_cookie"):
headers = {"set-cookie": "example-name=example-value"}
- return Response(200, headers=headers, request=request)
+ return AsyncResponse(200, headers=headers, request=request)
def test_set_cookie():
from httpcore import (
URL,
AsyncClient,
+ AsyncDispatcher,
+ AsyncRequest,
+ AsyncResponse,
CertTypes,
- Dispatcher,
RedirectBodyUnavailable,
RedirectLoop,
Request,
)
-class MockDispatch(Dispatcher):
+class MockDispatch(AsyncDispatcher):
async def send(
self,
- request: Request,
- stream: bool = False,
+ request: AsyncRequest,
verify: VerifyTypes = None,
cert: CertTypes = None,
timeout: TimeoutTypes = None,
- ) -> Response:
+ ) -> AsyncResponse:
if request.url.path == "/redirect_301":
status_code = codes.MOVED_PERMANENTLY
headers = {"location": "https://example.org/"}
- return Response(status_code, headers=headers, request=request)
+ return AsyncResponse(status_code, headers=headers, request=request)
elif request.url.path == "/redirect_302":
status_code = codes.FOUND
headers = {"location": "https://example.org/"}
- return Response(status_code, headers=headers, request=request)
+ return AsyncResponse(status_code, headers=headers, request=request)
elif request.url.path == "/redirect_303":
status_code = codes.SEE_OTHER
headers = {"location": "https://example.org/"}
- return Response(status_code, headers=headers, request=request)
+ return AsyncResponse(status_code, headers=headers, request=request)
elif request.url.path == "/relative_redirect":
headers = {"location": "/"}
- return Response(codes.SEE_OTHER, headers=headers, request=request)
+ return AsyncResponse(codes.SEE_OTHER, headers=headers, request=request)
elif request.url.path == "/no_scheme_redirect":
headers = {"location": "//example.org/"}
- return Response(codes.SEE_OTHER, headers=headers, request=request)
+ return AsyncResponse(codes.SEE_OTHER, headers=headers, request=request)
elif request.url.path == "/multiple_redirects":
params = parse_qs(request.url.query)
if redirect_count:
location += "?count=" + str(redirect_count)
headers = {"location": location} if count else {}
- return Response(code, headers=headers, request=request)
+ return AsyncResponse(code, headers=headers, request=request)
if request.url.path == "/redirect_loop":
headers = {"location": "/redirect_loop"}
- return Response(codes.SEE_OTHER, headers=headers, request=request)
+ return AsyncResponse(codes.SEE_OTHER, headers=headers, request=request)
elif request.url.path == "/cross_domain":
headers = {"location": "https://example.org/cross_domain_target"}
- return Response(codes.SEE_OTHER, headers=headers, request=request)
+ return AsyncResponse(codes.SEE_OTHER, headers=headers, request=request)
elif request.url.path == "/cross_domain_target":
headers = dict(request.headers.items())
content = json.dumps({"headers": headers}).encode()
- return Response(codes.OK, content=content, request=request)
+ return AsyncResponse(codes.OK, content=content, request=request)
elif request.url.path == "/redirect_body":
await request.read()
headers = {"location": "/redirect_body_target"}
- return Response(codes.PERMANENT_REDIRECT, headers=headers, request=request)
+ return AsyncResponse(
+ codes.PERMANENT_REDIRECT, headers=headers, request=request
+ )
elif request.url.path == "/redirect_body_target":
content = await request.read()
body = json.dumps({"body": content.decode()}).encode()
- return Response(codes.OK, content=body, request=request)
+ return AsyncResponse(codes.OK, content=body, request=request)
- return Response(codes.OK, content=b"Hello, world!", request=request)
+ return AsyncResponse(codes.OK, content=b"Hello, world!", request=request)
@pytest.mark.asyncio
"""
async with httpcore.ConnectionPool() as http:
response = await http.request("GET", "http://127.0.0.1:8000/")
+ await response.read()
assert len(http.active_connections) == 0
assert len(http.keepalive_connections) == 1
response = await http.request("GET", "http://127.0.0.1:8000/")
+ await response.read()
assert len(http.active_connections) == 0
assert len(http.keepalive_connections) == 1
"""
async with httpcore.ConnectionPool() as http:
response = await http.request("GET", "http://127.0.0.1:8000/")
+ await response.read()
assert len(http.active_connections) == 0
assert len(http.keepalive_connections) == 1
response = await http.request("GET", "http://localhost:8000/")
+ await response.read()
assert len(http.active_connections) == 0
assert len(http.keepalive_connections) == 2
async with httpcore.ConnectionPool(pool_limits=pool_limits) as http:
response = await http.request("GET", "http://127.0.0.1:8000/")
+ await response.read()
assert len(http.active_connections) == 0
assert len(http.keepalive_connections) == 1
response = await http.request("GET", "http://localhost:8000/")
+ await response.read()
assert len(http.active_connections) == 0
assert len(http.keepalive_connections) == 1
A streaming request should hold the connection open until the response is read.
"""
async with httpcore.ConnectionPool() as http:
- response = await http.request("GET", "http://127.0.0.1:8000/", stream=True)
+ response = await http.request("GET", "http://127.0.0.1:8000/")
assert len(http.active_connections) == 1
assert len(http.keepalive_connections) == 0
Multiple conncurrent requests should open multiple conncurrent connections.
"""
async with httpcore.ConnectionPool() as http:
- response_a = await http.request("GET", "http://127.0.0.1:8000/", stream=True)
+ response_a = await http.request("GET", "http://127.0.0.1:8000/")
assert len(http.active_connections) == 1
assert len(http.keepalive_connections) == 0
- response_b = await http.request("GET", "http://127.0.0.1:8000/", stream=True)
+ response_b = await http.request("GET", "http://127.0.0.1:8000/")
assert len(http.active_connections) == 2
assert len(http.keepalive_connections) == 0
headers = [(b"connection", b"close")]
async with httpcore.ConnectionPool() as http:
response = await http.request("GET", "http://127.0.0.1:8000/", headers=headers)
+ await response.read()
assert len(http.active_connections) == 0
assert len(http.keepalive_connections) == 0
A standard close should keep the connection open.
"""
async with httpcore.ConnectionPool() as http:
- response = await http.request("GET", "http://127.0.0.1:8000/", stream=True)
+ response = await http.request("GET", "http://127.0.0.1:8000/")
await response.read()
await response.close()
assert len(http.active_connections) == 0
A premature close should close the connection.
"""
async with httpcore.ConnectionPool() as http:
- response = await http.request("GET", "http://127.0.0.1:8000/", stream=True)
+ response = await http.request("GET", "http://127.0.0.1:8000/")
await response.close()
assert len(http.active_connections) == 0
assert len(http.keepalive_connections) == 0
async def test_get(server):
conn = HTTPConnection(origin="http://127.0.0.1:8000/")
response = await conn.request("GET", "http://127.0.0.1:8000/")
+ await response.read()
assert response.status_code == 200
assert response.content == b"Hello, world!"
"""
conn = HTTPConnection(origin="https://127.0.0.1:8001/", verify=False)
response = await conn.request("GET", "https://127.0.0.1:8001/")
+ await response.read()
assert response.status_code == 200
assert response.content == b"Hello, world!"
"""
conn = HTTPConnection(origin="https://127.0.0.1:8001/")
response = await conn.request("GET", "https://127.0.0.1:8001/", verify=False)
+ await response.read()
assert response.status_code == 200
assert response.content == b"Hello, world!"
--- /dev/null
+import json
+
+import pytest
+
+from httpcore import (
+ CertTypes,
+ Client,
+ Dispatcher,
+ Request,
+ Response,
+ TimeoutTypes,
+ VerifyTypes,
+)
+
+
+def streaming_body():
+ for part in [b"Hello", b", ", b"world!"]:
+ yield part
+
+
+class MockDispatch(Dispatcher):
+ def send(
+ self,
+ request: Request,
+ verify: VerifyTypes = None,
+ cert: CertTypes = None,
+ timeout: TimeoutTypes = None,
+ ) -> Response:
+ if request.url.path == "/streaming_response":
+ return Response(200, content=streaming_body(), request=request)
+ elif request.url.path == "/echo_request_body":
+ content = request.read()
+ return Response(200, content=content, request=request)
+ elif request.url.path == "/echo_request_body_streaming":
+ content = b"".join([part for part in request.stream()])
+ return Response(200, content=content, request=request)
+ else:
+ body = json.dumps({"hello": "world"}).encode()
+ return Response(200, content=body, request=request)
+
+
+def test_threaded_dispatch():
+ """
+ Use a syncronous 'Dispatcher' class with the client.
+ Calls to the dispatcher will end up running within a thread pool.
+ """
+ url = "https://example.org/"
+ with Client(dispatch=MockDispatch()) as client:
+ response = client.get(url)
+
+ assert response.status_code == 200
+ assert response.json() == {"hello": "world"}
+
+
+def test_threaded_streaming_response():
+ url = "https://example.org/streaming_response"
+ with Client(dispatch=MockDispatch()) as client:
+ response = client.get(url)
+
+ assert response.status_code == 200
+ assert response.text == "Hello, world!"
+
+
+def test_threaded_streaming_request():
+ url = "https://example.org/echo_request_body"
+ with Client(dispatch=MockDispatch()) as client:
+ response = client.post(url, data=streaming_body())
+
+ assert response.status_code == 200
+ assert response.text == "Hello, world!"
+
+
+def test_threaded_request_body():
+ url = "https://example.org/echo_request_body"
+ with Client(dispatch=MockDispatch()) as client:
+ response = client.post(url, data=b"Hello, world!")
+
+ assert response.status_code == 200
+ assert response.text == "Hello, world!"
+
+
+def test_threaded_request_body_streaming():
+ url = "https://example.org/echo_request_body_streaming"
+ with Client(dispatch=MockDispatch()) as client:
+ response = client.post(url, data=b"Hello, world!")
+
+ assert response.status_code == 200
+ assert response.text == "Hello, world!"
+
+
+def test_dispatch_class():
+ """
+ Use a syncronous 'Dispatcher' class directly.
+ """
+ url = "https://example.org/"
+ with MockDispatch() as dispatcher:
+ response = dispatcher.request("GET", url)
+
+ assert response.status_code == 200
+ assert response.json() == {"hello": "world"}
def test_no_content():
request = httpcore.Request("GET", "http://example.org")
- request.prepare()
- assert request.headers == httpcore.Headers(
- [(b"accept-encoding", b"deflate, gzip, br")]
- )
+ assert "Content-Length" not in request.headers
def test_content_length_header():
request = httpcore.Request("POST", "http://example.org", data=b"test 123")
- request.prepare()
- assert request.headers == httpcore.Headers(
- [
- (b"content-length", b"8"),
- (b"accept-encoding", b"deflate, gzip, br"),
- ]
- )
+ assert request.headers["Content-Length"] == "8"
def test_url_encoded_data():
- request = httpcore.Request("POST", "http://example.org", data={"test": "123"})
- request.prepare()
- assert request.headers == httpcore.Headers(
- [
- (b"content-length", b"8"),
- (b"accept-encoding", b"deflate, gzip, br"),
- (b"content-type", b"application/x-www-form-urlencoded"),
- ]
- )
- assert request.content == b"test=123"
+ for RequestClass in (httpcore.Request, httpcore.AsyncRequest):
+ request = RequestClass("POST", "http://example.org", data={"test": "123"})
+ assert request.headers["Content-Type"] == "application/x-www-form-urlencoded"
+ assert request.content == b"test=123"
+
+
+def test_json_encoded_data():
+ for RequestClass in (httpcore.Request, httpcore.AsyncRequest):
+ request = RequestClass("POST", "http://example.org", json={"test": 123})
+ assert request.headers["Content-Type"] == "application/json"
+ assert request.content == b'{"test": 123}'
def test_transfer_encoding_header():
- async def streaming_body(data):
+ def streaming_body(data):
yield data # pragma: nocover
data = streaming_body(b"test 123")
request = httpcore.Request("POST", "http://example.org", data=data)
- request.prepare()
- assert request.headers == httpcore.Headers(
- [
- (b"transfer-encoding", b"chunked"),
- (b"accept-encoding", b"deflate, gzip, br"),
- ]
- )
+ assert "Content-Length" not in request.headers
+ assert request.headers["Transfer-Encoding"] == "chunked"
def test_override_host_header():
- headers = [(b"host", b"1.2.3.4:80")]
+ headers = {"host": "1.2.3.4:80"}
request = httpcore.Request("GET", "http://example.org", headers=headers)
- request.prepare()
- assert request.headers == httpcore.Headers(
- [(b"accept-encoding", b"deflate, gzip, br"), (b"host", b"1.2.3.4:80")]
- )
+ assert request.headers["Host"] == "1.2.3.4:80"
def test_override_accept_encoding_header():
- headers = [(b"accept-encoding", b"identity")]
+ headers = {"Accept-Encoding": "identity"}
request = httpcore.Request("GET", "http://example.org", headers=headers)
- request.prepare()
- assert request.headers == httpcore.Headers(
- [(b"accept-encoding", b"identity")]
- )
+ assert request.headers["Accept-Encoding"] == "identity"
def test_override_content_length_header():
- async def streaming_body(data):
+ def streaming_body(data):
yield data # pragma: nocover
data = streaming_body(b"test 123")
- headers = [(b"content-length", b"8")]
+ headers = {"Content-Length": "8"}
request = httpcore.Request("POST", "http://example.org", data=data, headers=headers)
- request.prepare()
- assert request.headers == httpcore.Headers(
- [
- (b"accept-encoding", b"deflate, gzip, br"),
- (b"content-length", b"8"),
- ]
- )
+ assert request.headers["Content-Length"] == "8"
def test_url():
import httpcore
-async def streaming_body():
+def streaming_body():
+ yield b"Hello, "
+ yield b"world!"
+
+
+async def async_streaming_body():
yield b"Hello, "
yield b"world!"
assert response.encoding == "iso-8859-1"
-@pytest.mark.asyncio
-async def test_read_response():
+def test_read_response():
response = httpcore.Response(200, content=b"Hello, world!")
assert response.status_code == 200
assert response.encoding == "ascii"
assert response.is_closed
- content = await response.read()
+ content = response.read()
assert content == b"Hello, world!"
assert response.content == b"Hello, world!"
assert response.is_closed
-@pytest.mark.asyncio
-async def test_raw_interface():
+def test_raw_interface():
response = httpcore.Response(200, content=b"Hello, world!")
raw = b""
- async for part in response.raw():
+ for part in response.raw():
raw += part
assert raw == b"Hello, world!"
-@pytest.mark.asyncio
-async def test_stream_interface():
+def test_stream_interface():
response = httpcore.Response(200, content=b"Hello, world!")
content = b""
- async for part in response.stream():
+ for part in response.stream():
content += part
assert content == b"Hello, world!"
@pytest.mark.asyncio
-async def test_stream_interface_after_read():
+async def test_async_stream_interface():
+ response = httpcore.AsyncResponse(200, content=b"Hello, world!")
+
+ content = b""
+ async for part in response.stream():
+ content += part
+ assert content == b"Hello, world!"
+
+
+def test_stream_interface_after_read():
response = httpcore.Response(200, content=b"Hello, world!")
+ response.read()
+
+ content = b""
+ for part in response.stream():
+ content += part
+ assert content == b"Hello, world!"
+
+
+@pytest.mark.asyncio
+async def test_async_stream_interface_after_read():
+ response = httpcore.AsyncResponse(200, content=b"Hello, world!")
+
await response.read()
content = b""
assert content == b"Hello, world!"
-@pytest.mark.asyncio
-async def test_streaming_response():
+def test_streaming_response():
response = httpcore.Response(200, content=streaming_body())
assert response.status_code == 200
assert not response.is_closed
- content = await response.read()
+ content = response.read()
assert content == b"Hello, world!"
assert response.content == b"Hello, world!"
@pytest.mark.asyncio
-async def test_cannot_read_after_stream_consumed():
+async def test_async_streaming_response():
+ response = httpcore.AsyncResponse(200, content=async_streaming_body())
+
+ assert response.status_code == 200
+ assert not response.is_closed
+
+ content = await response.read()
+
+ assert content == b"Hello, world!"
+ assert response.content == b"Hello, world!"
+ assert response.is_closed
+
+
+def test_cannot_read_after_stream_consumed():
response = httpcore.Response(200, content=streaming_body())
+ content = b""
+ for part in response.stream():
+ content += part
+
+ with pytest.raises(httpcore.StreamConsumed):
+ response.read()
+
+
+@pytest.mark.asyncio
+async def test_async_cannot_read_after_stream_consumed():
+ response = httpcore.AsyncResponse(200, content=async_streaming_body())
+
content = b""
async for part in response.stream():
content += part
await response.read()
-@pytest.mark.asyncio
-async def test_cannot_read_after_response_closed():
+def test_cannot_read_after_response_closed():
response = httpcore.Response(200, content=streaming_body())
+ response.close()
+
+ with pytest.raises(httpcore.ResponseClosed):
+ response.read()
+
+
+@pytest.mark.asyncio
+async def test_async_cannot_read_after_response_closed():
+ response = httpcore.AsyncResponse(200, content=async_streaming_body())
+
await response.close()
with pytest.raises(httpcore.ResponseClosed):
assert response.reason_phrase == "OK"
+@threadpool
+def test_post_byte_iterator(server):
+ def data():
+ yield b"Hello"
+ yield b", "
+ yield b"world!"
+
+ response = httpcore.post("http://127.0.0.1:8000/", data=data())
+ assert response.status_code == 200
+ assert response.reason_phrase == "OK"
+
+
@threadpool
def test_options(server):
response = httpcore.options("http://127.0.0.1:8000/")
assert response.content == body
-@pytest.mark.asyncio
-async def test_streaming():
+def test_streaming():
body = b"test 123"
compressor = zlib.compressobj(9, zlib.DEFLATED, zlib.MAX_WBITS | 16)
- async def compress(body):
+ def compress(body):
yield compressor.compress(body)
yield compressor.flush()
headers = [(b"Content-Encoding", b"gzip")]
response = httpcore.Response(200, headers=headers, content=compress(body))
assert not hasattr(response, "body")
- assert await response.read() == body
+ assert response.read() == body
@pytest.mark.parametrize("header_value", (b"deflate", b"gzip", b"br"))