]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Finessing interface
authorTom Christie <tom@tomchristie.com>
Tue, 30 Apr 2019 10:26:11 +0000 (11:26 +0100)
committerTom Christie <tom@tomchristie.com>
Tue, 30 Apr 2019 10:26:11 +0000 (11:26 +0100)
httpcore/client.py
httpcore/config.py
httpcore/dispatch/http11.py
httpcore/models.py
httpcore/sync.py
httpcore/utils.py
tests/test_api.py
tests/test_config.py
tests/test_responses.py

index 139a56867997f1ea6afb4ae79e6a94cf969f6c81..5f56281ae98ac69503acf41ebc7df0c52d8e6127 100644 (file)
@@ -15,7 +15,7 @@ from .config import (
     TimeoutConfig,
 )
 from .dispatch.connection_pool import ConnectionPool
-from .models import URL, Request, Response
+from .models import URL, BodyTypes, HeaderTypes, Request, Response, URLTypes
 
 
 class Client:
@@ -37,14 +37,14 @@ class Client:
     async def request(
         self,
         method: str,
-        url: typing.Union[str, URL],
+        url: URLTypes,
         *,
-        body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"",
-        headers: typing.List[typing.Tuple[bytes, bytes]] = [],
+        body: BodyTypes = b"",
+        headers: HeaderTypes = None,
         stream: bool = False,
         allow_redirects: bool = True,
-        ssl: typing.Optional[SSLConfig] = None,
-        timeout: typing.Optional[TimeoutConfig] = None,
+        ssl: SSLConfig = None,
+        timeout: TimeoutConfig = None,
     ) -> Response:
         request = Request(method, url, headers=headers, body=body)
         self.prepare_request(request)
@@ -59,26 +59,74 @@ class Client:
 
     async def get(
         self,
-        url: typing.Union[str, URL],
+        url: URLTypes,
         *,
-        headers: typing.List[typing.Tuple[bytes, bytes]] = [],
+        headers: HeaderTypes = None,
         stream: bool = False,
-        ssl: typing.Optional[SSLConfig] = None,
-        timeout: typing.Optional[TimeoutConfig] = None,
+        allow_redirects: bool = True,
+        ssl: SSLConfig = None,
+        timeout: TimeoutConfig = None,
+    ) -> Response:
+        return await self.request(
+            "GET",
+            url,
+            headers=headers,
+            stream=stream,
+            allow_redirects=allow_redirects,
+            ssl=ssl,
+            timeout=timeout,
+        )
+
+    async def options(
+        self,
+        url: URLTypes,
+        *,
+        headers: HeaderTypes = None,
+        stream: bool = False,
+        allow_redirects: bool = True,
+        ssl: SSLConfig = None,
+        timeout: TimeoutConfig = None,
+    ) -> Response:
+        return await self.request(
+            "OPTIONS",
+            url,
+            headers=headers,
+            stream=stream,
+            allow_redirects=allow_redirects,
+            ssl=ssl,
+            timeout=timeout,
+        )
+
+    async def head(
+        self,
+        url: URLTypes,
+        *,
+        headers: HeaderTypes = None,
+        stream: bool = False,
+        allow_redirects: bool = False,  #  Note: Differs to usual default.
+        ssl: SSLConfig = None,
+        timeout: TimeoutConfig = None,
     ) -> Response:
         return await self.request(
-            "GET", url, headers=headers, stream=stream, ssl=ssl, timeout=timeout
+            "HEAD",
+            url,
+            headers=headers,
+            stream=stream,
+            allow_redirects=allow_redirects,
+            ssl=ssl,
+            timeout=timeout,
         )
 
     async def post(
         self,
-        url: typing.Union[str, URL],
+        url: URLTypes,
         *,
-        body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"",
-        headers: typing.List[typing.Tuple[bytes, bytes]] = [],
+        body: BodyTypes = b"",
+        headers: HeaderTypes = None,
         stream: bool = False,
-        ssl: typing.Optional[SSLConfig] = None,
-        timeout: typing.Optional[TimeoutConfig] = None,
+        allow_redirects: bool = True,
+        ssl: SSLConfig = None,
+        timeout: TimeoutConfig = None,
     ) -> Response:
         return await self.request(
             "POST",
@@ -86,6 +134,73 @@ class Client:
             body=body,
             headers=headers,
             stream=stream,
+            allow_redirects=allow_redirects,
+            ssl=ssl,
+            timeout=timeout,
+        )
+
+    async def put(
+        self,
+        url: URLTypes,
+        *,
+        body: BodyTypes = b"",
+        headers: HeaderTypes = None,
+        stream: bool = False,
+        allow_redirects: bool = True,
+        ssl: SSLConfig = None,
+        timeout: TimeoutConfig = None,
+    ) -> Response:
+        return await self.request(
+            "PUT",
+            url,
+            body=body,
+            headers=headers,
+            stream=stream,
+            allow_redirects=allow_redirects,
+            ssl=ssl,
+            timeout=timeout,
+        )
+
+    async def patch(
+        self,
+        url: URLTypes,
+        *,
+        body: BodyTypes = b"",
+        headers: HeaderTypes = None,
+        stream: bool = False,
+        allow_redirects: bool = True,
+        ssl: SSLConfig = None,
+        timeout: TimeoutConfig = None,
+    ) -> Response:
+        return await self.request(
+            "PATCH",
+            url,
+            body=body,
+            headers=headers,
+            stream=stream,
+            allow_redirects=allow_redirects,
+            ssl=ssl,
+            timeout=timeout,
+        )
+
+    async def delete(
+        self,
+        url: URLTypes,
+        *,
+        body: BodyTypes = b"",
+        headers: HeaderTypes = None,
+        stream: bool = False,
+        allow_redirects: bool = True,
+        ssl: SSLConfig = None,
+        timeout: TimeoutConfig = None,
+    ) -> Response:
+        return await self.request(
+            "DELETE",
+            url,
+            body=body,
+            headers=headers,
+            stream=stream,
+            allow_redirects=allow_redirects,
             ssl=ssl,
             timeout=timeout,
         )
@@ -99,14 +214,19 @@ class Client:
         *,
         stream: bool = False,
         allow_redirects: bool = True,
-        ssl: typing.Optional[SSLConfig] = None,
-        timeout: typing.Optional[TimeoutConfig] = None,
+        ssl: SSLConfig = None,
+        timeout: TimeoutConfig = None,
     ) -> Response:
-        options = {"stream": stream}  # type: typing.Dict[str, typing.Any]
+        options = {
+            "stream": stream,
+            "allow_redirects": allow_redirects,
+        }  # type: typing.Dict[str, typing.Any]
+
         if ssl is not None:
             options["ssl"] = ssl
         if timeout is not None:
             options["timeout"] = timeout
+
         return await self.adapter.send(request, **options)
 
     async def close(self) -> None:
index ef24a8b1a0046893140e1195d247fbe09fd79e7c..82fd125ff4fe018c1933fa3c147c4ac9570b3f4e 100644 (file)
@@ -71,6 +71,8 @@ class SSLConfig:
 
         context = ssl.create_default_context(purpose=ssl.Purpose.CLIENT_AUTH)
 
+        context.verify_mode = ssl.CERT_REQUIRED
+
         context.options |= ssl.OP_NO_SSLv2
         context.options |= ssl.OP_NO_SSLv3
         context.options |= ssl.OP_NO_COMPRESSION
index 8128dc18fd8027e5d7c925273875b8ed941625ca..62cb2facb3f254ef8ae10d8049963826805b6f51 100644 (file)
@@ -75,14 +75,14 @@ class HTTP11Connection(Adapter):
             event = await self._receive_event(timeout)
 
         assert isinstance(event, h11.Response)
-        reason = event.reason.decode("latin1")
+        reason_phrase = event.reason.decode("latin1")
         status_code = event.status_code
         headers = event.headers
         body = self._body_iter(timeout)
 
         response = Response(
             status_code=status_code,
-            reason=reason,
+            reason_phrase=reason_phrase,
             protocol="HTTP/1.1",
             headers=headers,
             body=body,
index 4b97211c1fcbc3757fd2116b0a0945c3467bb0b8..97e18dc56c58b81ca4ae4fb0d67d9226ec1803de 100644 (file)
@@ -1,4 +1,3 @@
-import http
 import typing
 from urllib.parse import urlsplit
 
@@ -11,12 +10,27 @@ from .decoders import (
     MultiDecoder,
 )
 from .exceptions import ResponseClosed, StreamConsumed
-from .utils import normalize_header_key, normalize_header_value
+from .status_codes import codes
+from .utils import get_reason_phrase, normalize_header_key, normalize_header_value
+
+URLTypes = typing.Union["URL", str]
+
+HeaderTypes = typing.Union[
+    "Headers",
+    typing.Dict[typing.AnyStr, typing.AnyStr],
+    typing.List[typing.Tuple[typing.AnyStr, typing.AnyStr]],
+]
+
+BodyTypes = typing.Union[bytes, typing.AsyncIterator[bytes]]
 
 
 class URL:
-    def __init__(self, url: str = "") -> None:
-        self.components = urlsplit(url)
+    def __init__(self, url: URLTypes) -> None:
+        if isinstance(url, str):
+            self.components = urlsplit(url)
+        else:
+            self.components = url.components
+
         if not self.components.scheme:
             raise ValueError("No scheme included in URL.")
         if self.components.scheme not in ("http", "https"):
@@ -106,13 +120,6 @@ class Origin:
         return hash((self.is_ssl, self.hostname, self.port))
 
 
-HeaderTypes = typing.Union[
-    "Headers",
-    typing.Dict[typing.AnyStr, typing.AnyStr],
-    typing.List[typing.Tuple[typing.AnyStr, typing.AnyStr]],
-]
-
-
 class Headers(typing.MutableMapping[str, str]):
     """
     A case-insensitive multidict.
@@ -239,7 +246,7 @@ class Request:
         url: typing.Union[str, URL],
         *,
         headers: HeaderTypes = None,
-        body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"",
+        body: BodyTypes = b"",
     ):
         self.method = method.upper()
         self.url = URL(url) if isinstance(url, str) else url
@@ -298,22 +305,19 @@ class Response:
         self,
         status_code: int,
         *,
-        reason: typing.Optional[str] = None,
-        protocol: typing.Optional[str] = None,
-        headers: typing.List[typing.Tuple[bytes, bytes]] = [],
-        body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"",
+        reason_phrase: str = None,
+        protocol: str = None,
+        headers: HeaderTypes = None,
+        body: BodyTypes = b"",
         on_close: typing.Callable = None,
         request: Request = None,
         history: typing.List["Response"] = None,
     ):
         self.status_code = status_code
-        if not reason:
-            try:
-                self.reason = http.HTTPStatus(status_code).phrase
-            except ValueError as exc:
-                self.reason = ""
+        if reason_phrase is None:
+            self.reason_phrase = get_reason_phrase(status_code)
         else:
-            self.reason = reason
+            self.reason_phrase = reason_phrase
         self.protocol = protocol
         self.headers = Headers(headers)
         self.on_close = on_close
@@ -397,5 +401,13 @@ class Response:
     @property
     def is_redirect(self) -> bool:
         return (
-            self.status_code in (301, 302, 303, 307, 308) and "location" in self.headers
+            self.status_code
+            in (
+                codes.moved_permanently,
+                codes.found,
+                codes.see_other,
+                codes.temporary_redirect,
+                codes.permanent_redirect,
+            )
+            and "location" in self.headers
         )
index 2d58f9a1874b53c279ba7bd190691244d00c8dd9..907dc80b12bd45a10b125de35a6ddc284a12f4ec 100644 (file)
@@ -18,8 +18,8 @@ class SyncResponse:
         return self._response.status_code
 
     @property
-    def reason(self) -> str:
-        return self._response.reason
+    def reason_phrase(self) -> str:
+        return self._response.reason_phrase
 
     @property
     def headers(self) -> Headers:
index 419e7ec27424b1097d897dc155e887f78340ff08..aa5e14ee91249ccd1b6b236d23fa24efd539add6 100644 (file)
@@ -1,3 +1,4 @@
+import http
 import typing
 from urllib.parse import quote
 
@@ -69,3 +70,13 @@ def normalize_header_value(value: typing.AnyStr) -> bytes:
     if isinstance(value, bytes):
         return value
     return value.encode("latin-1")
+
+
+def get_reason_phrase(status_code: int) -> str:
+    """
+    Return an HTTP reason phrase, eg. "OK" for 200, or "Not Found" for 404.
+    """
+    try:
+        return http.HTTPStatus(status_code).phrase
+    except ValueError as exc:
+        return ""
index 4622849b561fe225a0c4ad54c273e52f7089bf15..09711f755814e73d376a43797197e3f76ff461c6 100644 (file)
@@ -5,18 +5,18 @@ import httpcore
 
 @pytest.mark.asyncio
 async def test_get(server):
+    url = "http://127.0.0.1:8000/"
     async with httpcore.Client() as client:
-        response = await client.request("GET", "http://127.0.0.1:8000/")
+        response = await client.get(url)
     assert response.status_code == 200
     assert response.body == b"Hello, world!"
 
 
 @pytest.mark.asyncio
 async def test_post(server):
+    url = "http://127.0.0.1:8000/"
     async with httpcore.Client() as client:
-        response = await client.request(
-            "POST", "http://127.0.0.1:8000/", body=b"Hello, world!"
-        )
+        response = await client.post(url, body=b"Hello, world!")
     assert response.status_code == 200
 
 
index e4ce64a46c3ebc9bbe0292e73cdd418b164aed56..411e1e7e208b7cd71f9bd43530cca11ca579fbfd 100644 (file)
@@ -1,6 +1,24 @@
+import ssl
+
+import pytest
+
 import httpcore
 
 
+@pytest.mark.asyncio
+async def test_load_ssl_config():
+    ssl_config = httpcore.SSLConfig()
+    context = await ssl_config.load_ssl_context()
+    assert context.verify_mode == ssl.VerifyMode.CERT_REQUIRED
+
+
+@pytest.mark.asyncio
+async def test_load_ssl_config_no_verify(verify=False):
+    ssl_config = httpcore.SSLConfig(verify=False)
+    context = await ssl_config.load_ssl_context()
+    assert context.verify_mode == ssl.VerifyMode.CERT_NONE
+
+
 def test_ssl_repr():
     ssl = httpcore.SSLConfig(verify=False)
     assert repr(ssl) == "SSLConfig(cert=None, verify=False)"
index bb930bdb204c21ff75259821024045740696d4f8..9bd3af80ec85549d56b6d905f07d5ee004d84026 100644 (file)
@@ -11,7 +11,7 @@ async def streaming_body():
 def test_response():
     response = httpcore.Response(200, body=b"Hello, world!")
     assert response.status_code == 200
-    assert response.reason == "OK"
+    assert response.reason_phrase == "OK"
     assert response.body == b"Hello, world!"
     assert response.is_closed
 
@@ -71,4 +71,4 @@ async def test_cannot_read_after_response_closed():
 def test_unknown_status_code():
     response = httpcore.Response(600)
     assert response.status_code == 600
-    assert response.reason == ""
+    assert response.reason_phrase == ""