]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Support cookie persistence 73/head
authorTom Christie <tom@tomchristie.com>
Fri, 17 May 2019 11:41:36 +0000 (12:41 +0100)
committerTom Christie <tom@tomchristie.com>
Fri, 17 May 2019 11:41:36 +0000 (12:41 +0100)
httpcore/client.py
httpcore/models.py
tests/client/test_cookie_handling.py

index 8361fe7325171b047e76e4f23c3158aa7abe4a72..3da58e6f11105e910e0c49f261ffd016ad30c8b5 100644 (file)
@@ -1,6 +1,5 @@
 import asyncio
 import typing
-from http.cookiejar import CookieJar
 from types import TracebackType
 
 from .auth import HTTPBasicAuth
@@ -19,6 +18,8 @@ from .interfaces import ConcurrencyBackend, Dispatcher
 from .models import (
     URL,
     AuthTypes,
+    Cookies,
+    CookieTypes,
     Headers,
     HeaderTypes,
     QueryParamTypes,
@@ -35,6 +36,7 @@ class AsyncClient:
     def __init__(
         self,
         auth: AuthTypes = None,
+        cookies: CookieTypes = None,
         ssl: SSLConfig = DEFAULT_SSL_CONFIG,
         timeout: TimeoutConfig = DEFAULT_TIMEOUT_CONFIG,
         pool_limits: PoolLimits = DEFAULT_POOL_LIMITS,
@@ -48,6 +50,7 @@ class AsyncClient:
             )
 
         self.auth = auth
+        self.cookies = Cookies(cookies)
         self.max_redirects = max_redirects
         self.dispatch = dispatch
 
@@ -57,7 +60,7 @@ class AsyncClient:
         *,
         query_params: QueryParamTypes = None,
         headers: HeaderTypes = None,
-        cookies: CookieJar = None,
+        cookies: CookieTypes = None,
         stream: bool = False,
         auth: AuthTypes = None,
         allow_redirects: bool = True,
@@ -83,7 +86,7 @@ class AsyncClient:
         *,
         query_params: QueryParamTypes = None,
         headers: HeaderTypes = None,
-        cookies: CookieJar = None,
+        cookies: CookieTypes = None,
         stream: bool = False,
         auth: AuthTypes = None,
         allow_redirects: bool = True,
@@ -109,7 +112,7 @@ class AsyncClient:
         *,
         query_params: QueryParamTypes = None,
         headers: HeaderTypes = None,
-        cookies: CookieJar = None,
+        cookies: CookieTypes = None,
         stream: bool = False,
         auth: AuthTypes = None,
         allow_redirects: bool = False,  #  Note: Differs to usual default.
@@ -136,7 +139,7 @@ class AsyncClient:
         data: RequestData = b"",
         query_params: QueryParamTypes = None,
         headers: HeaderTypes = None,
-        cookies: CookieJar = None,
+        cookies: CookieTypes = None,
         stream: bool = False,
         auth: AuthTypes = None,
         allow_redirects: bool = True,
@@ -164,7 +167,7 @@ class AsyncClient:
         data: RequestData = b"",
         query_params: QueryParamTypes = None,
         headers: HeaderTypes = None,
-        cookies: CookieJar = None,
+        cookies: CookieTypes = None,
         stream: bool = False,
         auth: AuthTypes = None,
         allow_redirects: bool = True,
@@ -192,7 +195,7 @@ class AsyncClient:
         data: RequestData = b"",
         query_params: QueryParamTypes = None,
         headers: HeaderTypes = None,
-        cookies: CookieJar = None,
+        cookies: CookieTypes = None,
         stream: bool = False,
         auth: AuthTypes = None,
         allow_redirects: bool = True,
@@ -220,7 +223,7 @@ class AsyncClient:
         data: RequestData = b"",
         query_params: QueryParamTypes = None,
         headers: HeaderTypes = None,
-        cookies: CookieJar = None,
+        cookies: CookieTypes = None,
         stream: bool = False,
         auth: AuthTypes = None,
         allow_redirects: bool = True,
@@ -249,7 +252,7 @@ class AsyncClient:
         data: RequestData = b"",
         query_params: QueryParamTypes = None,
         headers: HeaderTypes = None,
-        cookies: CookieJar = None,
+        cookies: CookieTypes = None,
         stream: bool = False,
         auth: AuthTypes = None,
         allow_redirects: bool = True,
@@ -262,7 +265,7 @@ class AsyncClient:
             data=data,
             query_params=query_params,
             headers=headers,
-            cookies=cookies,
+            cookies=self.merge_cookies(cookies),
         )
         self.prepare_request(request)
         response = await self.send(
@@ -278,6 +281,13 @@ class AsyncClient:
     def prepare_request(self, request: Request) -> None:
         request.prepare()
 
+    def merge_cookies(self, cookies: CookieTypes = None) -> typing.Optional[CookieTypes]:
+        if cookies or self.cookies:
+            merged_cookies = Cookies(self.cookies)
+            merged_cookies.update(cookies)
+            return merged_cookies
+        return cookies
+
     async def send(
         self,
         request: Request,
@@ -334,6 +344,7 @@ class AsyncClient:
                 request, stream=stream, ssl=ssl, timeout=timeout
             )
             response.history = list(history)
+            self.cookies.extract_cookies(response)
             history = [response] + history
             if not response.is_redirect:
                 break
@@ -365,7 +376,8 @@ class AsyncClient:
         url = self.redirect_url(request, response)
         headers = self.redirect_headers(request, url)
         content = self.redirect_content(request, method)
-        return Request(method=method, url=url, headers=headers, data=content)
+        cookies = self.merge_cookies(request.cookies)
+        return Request(method=method, url=url, headers=headers, data=content, cookies=cookies)
 
     def redirect_method(self, request: Request, response: Response) -> str:
         """
@@ -466,6 +478,10 @@ class Client:
         )
         self._loop = asyncio.new_event_loop()
 
+    @property
+    def cookies(self) -> Cookies:
+        return self._client.cookies
+
     def request(
         self,
         method: str,
@@ -474,7 +490,7 @@ class Client:
         data: RequestData = b"",
         query_params: QueryParamTypes = None,
         headers: HeaderTypes = None,
-        cookies: CookieJar = None,
+        cookies: CookieTypes = None,
         stream: bool = False,
         auth: AuthTypes = None,
         allow_redirects: bool = True,
@@ -487,7 +503,7 @@ class Client:
             data=data,
             query_params=query_params,
             headers=headers,
-            cookies=cookies,
+            cookies=self._client.merge_cookies(cookies),
         )
         self.prepare_request(request)
         response = self.send(
@@ -506,7 +522,7 @@ class Client:
         *,
         query_params: QueryParamTypes = None,
         headers: HeaderTypes = None,
-        cookies: CookieJar = None,
+        cookies: CookieTypes = None,
         stream: bool = False,
         auth: AuthTypes = None,
         allow_redirects: bool = True,
@@ -531,7 +547,7 @@ class Client:
         *,
         query_params: QueryParamTypes = None,
         headers: HeaderTypes = None,
-        cookies: CookieJar = None,
+        cookies: CookieTypes = None,
         stream: bool = False,
         auth: AuthTypes = None,
         allow_redirects: bool = True,
@@ -556,7 +572,7 @@ class Client:
         *,
         query_params: QueryParamTypes = None,
         headers: HeaderTypes = None,
-        cookies: CookieJar = None,
+        cookies: CookieTypes = None,
         stream: bool = False,
         auth: AuthTypes = None,
         allow_redirects: bool = False,  #  Note: Differs to usual default.
@@ -582,7 +598,7 @@ class Client:
         data: RequestData = b"",
         query_params: QueryParamTypes = None,
         headers: HeaderTypes = None,
-        cookies: CookieJar = None,
+        cookies: CookieTypes = None,
         stream: bool = False,
         auth: AuthTypes = None,
         allow_redirects: bool = True,
@@ -609,7 +625,7 @@ class Client:
         data: RequestData = b"",
         query_params: QueryParamTypes = None,
         headers: HeaderTypes = None,
-        cookies: CookieJar = None,
+        cookies: CookieTypes = None,
         stream: bool = False,
         auth: AuthTypes = None,
         allow_redirects: bool = True,
@@ -636,7 +652,7 @@ class Client:
         data: RequestData = b"",
         query_params: QueryParamTypes = None,
         headers: HeaderTypes = None,
-        cookies: CookieJar = None,
+        cookies: CookieTypes = None,
         stream: bool = False,
         auth: AuthTypes = None,
         allow_redirects: bool = True,
@@ -663,7 +679,7 @@ class Client:
         data: RequestData = b"",
         query_params: QueryParamTypes = None,
         headers: HeaderTypes = None,
-        cookies: CookieJar = None,
+        cookies: CookieTypes = None,
         stream: bool = False,
         auth: AuthTypes = None,
         allow_redirects: bool = True,
index 7fc4e65f827185a62ac7a67fc863c0ffe2ecfa86..d5b8e7f2ffa32e5a935168c0051eff9a1b38a530 100644 (file)
@@ -488,8 +488,8 @@ class Request:
         self.url = URL(url, query_params=query_params)
         self.headers = Headers(headers)
         if cookies:
-            cookies = Cookies(cookies)
-            cookies.set_cookie_header(self)
+            self._cookies = Cookies(cookies)
+            self._cookies.set_cookie_header(self)
 
         if isinstance(data, bytes):
             self.is_streaming = False
@@ -547,6 +547,12 @@ class Request:
         for item in reversed(auto_headers):
             self.headers.raw.insert(0, item)
 
+    @property
+    def cookies(self) -> "Cookies":
+        if not hasattr(self, "_cookies"):
+            self._cookies = Cookies()
+        return self._cookies
+
     def __repr__(self) -> str:
         class_name = self.__class__.__name__
         url = str(self.url)
@@ -874,7 +880,9 @@ class Cookies(MutableMapping):
                 for key, value in cookies.items():
                     self.set(key, value)
         elif isinstance(cookies, Cookies):
-            self.jar = cookies.jar
+            self.jar = CookieJar()
+            for cookie in cookies.jar:
+                self.jar.set_cookie(cookie)
         else:
             self.jar = cookies
 
index 64fd9fc5f6b6caf1fbdc98e7bee1ea98079781a6..7fe057d61039ab48ccf456440c1aeb41f6854959 100644 (file)
@@ -104,3 +104,23 @@ def test_get_cookie():
 
     assert response.status_code == 200
     assert response.cookies["example-name"] == "example-value"
+    assert client.cookies["example-name"] == "example-value"
+
+
+def test_cookie_persistence():
+    """
+    Ensure that Client instances persist cookies between requests.
+    """
+    with Client(dispatch=MockDispatch()) as client:
+        response = client.get("http://example.org/echo_cookies")
+        assert response.status_code == 200
+        assert json.loads(response.text) == {"cookies": None}
+
+        response = client.get("http://example.org/set_cookie")
+        assert response.status_code == 200
+        assert response.cookies["example-name"] == "example-value"
+        assert client.cookies["example-name"] == "example-value"
+
+        response = client.get("http://example.org/echo_cookies")
+        assert response.status_code == 200
+        assert json.loads(response.text) == {"cookies": "example-name=example-value"}