From: Tom Christie Date: Fri, 17 May 2019 11:41:36 +0000 (+0100) Subject: Support cookie persistence X-Git-Tag: 0.3.1~18^2~1 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=f63148aa7326de1ac2629b70f6c77e6a496df3f2;p=thirdparty%2Fhttpx.git Support cookie persistence --- diff --git a/httpcore/client.py b/httpcore/client.py index 8361fe73..3da58e6f 100644 --- a/httpcore/client.py +++ b/httpcore/client.py @@ -1,6 +1,5 @@ import asyncio import typing -from http.cookiejar import CookieJar from types import TracebackType from .auth import HTTPBasicAuth @@ -19,6 +18,8 @@ from .interfaces import ConcurrencyBackend, Dispatcher from .models import ( URL, AuthTypes, + Cookies, + CookieTypes, Headers, HeaderTypes, QueryParamTypes, @@ -35,6 +36,7 @@ class AsyncClient: def __init__( self, auth: AuthTypes = None, + cookies: CookieTypes = None, ssl: SSLConfig = DEFAULT_SSL_CONFIG, timeout: TimeoutConfig = DEFAULT_TIMEOUT_CONFIG, pool_limits: PoolLimits = DEFAULT_POOL_LIMITS, @@ -48,6 +50,7 @@ class AsyncClient: ) self.auth = auth + self.cookies = Cookies(cookies) self.max_redirects = max_redirects self.dispatch = dispatch @@ -57,7 +60,7 @@ class AsyncClient: *, query_params: QueryParamTypes = None, headers: HeaderTypes = None, - cookies: CookieJar = None, + cookies: CookieTypes = None, stream: bool = False, auth: AuthTypes = None, allow_redirects: bool = True, @@ -83,7 +86,7 @@ class AsyncClient: *, query_params: QueryParamTypes = None, headers: HeaderTypes = None, - cookies: CookieJar = None, + cookies: CookieTypes = None, stream: bool = False, auth: AuthTypes = None, allow_redirects: bool = True, @@ -109,7 +112,7 @@ class AsyncClient: *, query_params: QueryParamTypes = None, headers: HeaderTypes = None, - cookies: CookieJar = None, + cookies: CookieTypes = None, stream: bool = False, auth: AuthTypes = None, allow_redirects: bool = False, #  Note: Differs to usual default. @@ -136,7 +139,7 @@ class AsyncClient: data: RequestData = b"", query_params: QueryParamTypes = None, headers: HeaderTypes = None, - cookies: CookieJar = None, + cookies: CookieTypes = None, stream: bool = False, auth: AuthTypes = None, allow_redirects: bool = True, @@ -164,7 +167,7 @@ class AsyncClient: data: RequestData = b"", query_params: QueryParamTypes = None, headers: HeaderTypes = None, - cookies: CookieJar = None, + cookies: CookieTypes = None, stream: bool = False, auth: AuthTypes = None, allow_redirects: bool = True, @@ -192,7 +195,7 @@ class AsyncClient: data: RequestData = b"", query_params: QueryParamTypes = None, headers: HeaderTypes = None, - cookies: CookieJar = None, + cookies: CookieTypes = None, stream: bool = False, auth: AuthTypes = None, allow_redirects: bool = True, @@ -220,7 +223,7 @@ class AsyncClient: data: RequestData = b"", query_params: QueryParamTypes = None, headers: HeaderTypes = None, - cookies: CookieJar = None, + cookies: CookieTypes = None, stream: bool = False, auth: AuthTypes = None, allow_redirects: bool = True, @@ -249,7 +252,7 @@ class AsyncClient: data: RequestData = b"", query_params: QueryParamTypes = None, headers: HeaderTypes = None, - cookies: CookieJar = None, + cookies: CookieTypes = None, stream: bool = False, auth: AuthTypes = None, allow_redirects: bool = True, @@ -262,7 +265,7 @@ class AsyncClient: data=data, query_params=query_params, headers=headers, - cookies=cookies, + cookies=self.merge_cookies(cookies), ) self.prepare_request(request) response = await self.send( @@ -278,6 +281,13 @@ class AsyncClient: def prepare_request(self, request: Request) -> None: request.prepare() + def merge_cookies(self, cookies: CookieTypes = None) -> typing.Optional[CookieTypes]: + if cookies or self.cookies: + merged_cookies = Cookies(self.cookies) + merged_cookies.update(cookies) + return merged_cookies + return cookies + async def send( self, request: Request, @@ -334,6 +344,7 @@ class AsyncClient: request, stream=stream, ssl=ssl, timeout=timeout ) response.history = list(history) + self.cookies.extract_cookies(response) history = [response] + history if not response.is_redirect: break @@ -365,7 +376,8 @@ class AsyncClient: url = self.redirect_url(request, response) headers = self.redirect_headers(request, url) content = self.redirect_content(request, method) - return Request(method=method, url=url, headers=headers, data=content) + cookies = self.merge_cookies(request.cookies) + return Request(method=method, url=url, headers=headers, data=content, cookies=cookies) def redirect_method(self, request: Request, response: Response) -> str: """ @@ -466,6 +478,10 @@ class Client: ) self._loop = asyncio.new_event_loop() + @property + def cookies(self) -> Cookies: + return self._client.cookies + def request( self, method: str, @@ -474,7 +490,7 @@ class Client: data: RequestData = b"", query_params: QueryParamTypes = None, headers: HeaderTypes = None, - cookies: CookieJar = None, + cookies: CookieTypes = None, stream: bool = False, auth: AuthTypes = None, allow_redirects: bool = True, @@ -487,7 +503,7 @@ class Client: data=data, query_params=query_params, headers=headers, - cookies=cookies, + cookies=self._client.merge_cookies(cookies), ) self.prepare_request(request) response = self.send( @@ -506,7 +522,7 @@ class Client: *, query_params: QueryParamTypes = None, headers: HeaderTypes = None, - cookies: CookieJar = None, + cookies: CookieTypes = None, stream: bool = False, auth: AuthTypes = None, allow_redirects: bool = True, @@ -531,7 +547,7 @@ class Client: *, query_params: QueryParamTypes = None, headers: HeaderTypes = None, - cookies: CookieJar = None, + cookies: CookieTypes = None, stream: bool = False, auth: AuthTypes = None, allow_redirects: bool = True, @@ -556,7 +572,7 @@ class Client: *, query_params: QueryParamTypes = None, headers: HeaderTypes = None, - cookies: CookieJar = None, + cookies: CookieTypes = None, stream: bool = False, auth: AuthTypes = None, allow_redirects: bool = False, #  Note: Differs to usual default. @@ -582,7 +598,7 @@ class Client: data: RequestData = b"", query_params: QueryParamTypes = None, headers: HeaderTypes = None, - cookies: CookieJar = None, + cookies: CookieTypes = None, stream: bool = False, auth: AuthTypes = None, allow_redirects: bool = True, @@ -609,7 +625,7 @@ class Client: data: RequestData = b"", query_params: QueryParamTypes = None, headers: HeaderTypes = None, - cookies: CookieJar = None, + cookies: CookieTypes = None, stream: bool = False, auth: AuthTypes = None, allow_redirects: bool = True, @@ -636,7 +652,7 @@ class Client: data: RequestData = b"", query_params: QueryParamTypes = None, headers: HeaderTypes = None, - cookies: CookieJar = None, + cookies: CookieTypes = None, stream: bool = False, auth: AuthTypes = None, allow_redirects: bool = True, @@ -663,7 +679,7 @@ class Client: data: RequestData = b"", query_params: QueryParamTypes = None, headers: HeaderTypes = None, - cookies: CookieJar = None, + cookies: CookieTypes = None, stream: bool = False, auth: AuthTypes = None, allow_redirects: bool = True, diff --git a/httpcore/models.py b/httpcore/models.py index 7fc4e65f..d5b8e7f2 100644 --- a/httpcore/models.py +++ b/httpcore/models.py @@ -488,8 +488,8 @@ class Request: self.url = URL(url, query_params=query_params) self.headers = Headers(headers) if cookies: - cookies = Cookies(cookies) - cookies.set_cookie_header(self) + self._cookies = Cookies(cookies) + self._cookies.set_cookie_header(self) if isinstance(data, bytes): self.is_streaming = False @@ -547,6 +547,12 @@ class Request: for item in reversed(auto_headers): self.headers.raw.insert(0, item) + @property + def cookies(self) -> "Cookies": + if not hasattr(self, "_cookies"): + self._cookies = Cookies() + return self._cookies + def __repr__(self) -> str: class_name = self.__class__.__name__ url = str(self.url) @@ -874,7 +880,9 @@ class Cookies(MutableMapping): for key, value in cookies.items(): self.set(key, value) elif isinstance(cookies, Cookies): - self.jar = cookies.jar + self.jar = CookieJar() + for cookie in cookies.jar: + self.jar.set_cookie(cookie) else: self.jar = cookies diff --git a/tests/client/test_cookie_handling.py b/tests/client/test_cookie_handling.py index 64fd9fc5..7fe057d6 100644 --- a/tests/client/test_cookie_handling.py +++ b/tests/client/test_cookie_handling.py @@ -104,3 +104,23 @@ def test_get_cookie(): assert response.status_code == 200 assert response.cookies["example-name"] == "example-value" + assert client.cookies["example-name"] == "example-value" + + +def test_cookie_persistence(): + """ + Ensure that Client instances persist cookies between requests. + """ + with Client(dispatch=MockDispatch()) as client: + response = client.get("http://example.org/echo_cookies") + assert response.status_code == 200 + assert json.loads(response.text) == {"cookies": None} + + response = client.get("http://example.org/set_cookie") + assert response.status_code == 200 + assert response.cookies["example-name"] == "example-value" + assert client.cookies["example-name"] == "example-value" + + response = client.get("http://example.org/echo_cookies") + assert response.status_code == 200 + assert json.loads(response.text) == {"cookies": "example-name=example-value"}