]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Refactor tests to use `MockTransport(<handler_function>)` (#1281)
authorTom Christie <tom@tomchristie.com>
Sat, 12 Sep 2020 10:16:10 +0000 (11:16 +0100)
committerGitHub <noreply@github.com>
Sat, 12 Sep 2020 10:16:10 +0000 (11:16 +0100)
* Support Response(content=<bytes iterator>)

* Update test for merged master

* Add MockTransport for test cases

* Use MockTransport for redirect tests

* Reduce change footprint

* Reduce change footprint

* Clean up headers slightly

* Update requirements.txt

Co-authored-by: Florimond Manca <florimond.manca@gmail.com>
requirements.txt
tests/client/test_cookies.py
tests/client/test_headers.py
tests/client/test_queryparams.py
tests/client/test_redirects.py
tests/test_multipart.py
tests/utils.py

index b871b15cdb6753cb76494b00d30d77541955b5c6..037fb668c436192375bf5f69e8142d6b2544f19f 100644 (file)
@@ -18,7 +18,7 @@ black==20.8b1
 cryptography
 flake8
 flake8-bugbear
-flake8-pie
+flake8-pie==0.5.*
 isort==5.*
 mypy
 pytest==5.*
index 8cd6be8394836523826e384406b45bc179235ba9..af614effb69d79b8571ebd8542c13fcbf0a05a99 100644 (file)
@@ -1,43 +1,19 @@
-import typing
+import json
 from http.cookiejar import Cookie, CookieJar
 
-import httpcore
-
 import httpx
-from httpx._content_streams import ByteStream, ContentStream, JSONStream
-
-
-def get_header_value(headers, key, default=None):
-    lookup = key.encode("ascii").lower()
-    for header_key, header_value in headers:
-        if header_key.lower() == lookup:
-            return header_value.decode("ascii")
-    return default
-
-
-class MockTransport(httpcore.SyncHTTPTransport):
-    def request(
-        self,
-        method: bytes,
-        url: typing.Tuple[bytes, bytes, typing.Optional[int], bytes],
-        headers: typing.List[typing.Tuple[bytes, bytes]] = None,
-        stream: httpcore.SyncByteStream = None,
-        timeout: typing.Mapping[str, typing.Optional[float]] = None,
-    ) -> typing.Tuple[
-        bytes, int, bytes, typing.List[typing.Tuple[bytes, bytes]], ContentStream
-    ]:
-        host, scheme, port, path = url
-        body: ContentStream
-        if path.startswith(b"/echo_cookies"):
-            cookie = get_header_value(headers, "cookie")
-            body = JSONStream({"cookies": cookie})
-            return b"HTTP/1.1", 200, b"OK", [], body
-        elif path.startswith(b"/set_cookie"):
-            headers = [(b"set-cookie", b"example-name=example-value")]
-            body = ByteStream(b"")
-            return b"HTTP/1.1", 200, b"OK", headers, body
-        else:
-            raise NotImplementedError()  # pragma: no cover
+from tests.utils import MockTransport
+
+
+def get_and_set_cookies(request: httpx.Request) -> httpx.Response:
+    if request.url.path == "/echo_cookies":
+        data = {"cookies": request.headers.get("cookie")}
+        content = json.dumps(data).encode("utf-8")
+        return httpx.Response(200, content=content)
+    elif request.url.path == "/set_cookie":
+        return httpx.Response(200, headers={"set-cookie": "example-name=example-value"})
+    else:
+        raise NotImplementedError()  # pragma: no cover
 
 
 def test_set_cookie() -> None:
@@ -47,7 +23,7 @@ def test_set_cookie() -> None:
     url = "http://example.org/echo_cookies"
     cookies = {"example-name": "example-value"}
 
-    client = httpx.Client(transport=MockTransport())
+    client = httpx.Client(transport=MockTransport(get_and_set_cookies))
     response = client.get(url, cookies=cookies)
 
     assert response.status_code == 200
@@ -82,7 +58,7 @@ def test_set_cookie_with_cookiejar() -> None:
     )
     cookies.set_cookie(cookie)
 
-    client = httpx.Client(transport=MockTransport())
+    client = httpx.Client(transport=MockTransport(get_and_set_cookies))
     response = client.get(url, cookies=cookies)
 
     assert response.status_code == 200
@@ -117,7 +93,7 @@ def test_setting_client_cookies_to_cookiejar() -> None:
     )
     cookies.set_cookie(cookie)
 
-    client = httpx.Client(transport=MockTransport())
+    client = httpx.Client(transport=MockTransport(get_and_set_cookies))
     client.cookies = cookies  # type: ignore
     response = client.get(url)
 
@@ -134,7 +110,7 @@ def test_set_cookie_with_cookies_model() -> None:
     cookies = httpx.Cookies()
     cookies["example-name"] = "example-value"
 
-    client = httpx.Client(transport=MockTransport())
+    client = httpx.Client(transport=MockTransport(get_and_set_cookies))
     response = client.get(url, cookies=cookies)
 
     assert response.status_code == 200
@@ -144,7 +120,7 @@ def test_set_cookie_with_cookies_model() -> None:
 def test_get_cookie() -> None:
     url = "http://example.org/set_cookie"
 
-    client = httpx.Client(transport=MockTransport())
+    client = httpx.Client(transport=MockTransport(get_and_set_cookies))
     response = client.get(url)
 
     assert response.status_code == 200
@@ -156,7 +132,7 @@ def test_cookie_persistence() -> None:
     """
     Ensure that Client instances persist cookies between requests.
     """
-    client = httpx.Client(transport=MockTransport())
+    client = httpx.Client(transport=MockTransport(get_and_set_cookies))
 
     response = client.get("http://example.org/echo_cookies")
     assert response.status_code == 200
index c86eae33c17a1d93407a47c85bafa2ac0b8935e7..d968616f4ec08e342f0a2086d90727601a15051d 100755 (executable)
@@ -1,31 +1,17 @@
 #!/usr/bin/env python3
 
-import typing
+import json
 
-import httpcore
 import pytest
 
 import httpx
-from httpx._content_streams import ContentStream, JSONStream
-
-
-class MockTransport(httpcore.SyncHTTPTransport):
-    def request(
-        self,
-        method: bytes,
-        url: typing.Tuple[bytes, bytes, typing.Optional[int], bytes],
-        headers: typing.List[typing.Tuple[bytes, bytes]] = None,
-        stream: httpcore.SyncByteStream = None,
-        timeout: typing.Mapping[str, typing.Optional[float]] = None,
-    ) -> typing.Tuple[
-        bytes, int, bytes, typing.List[typing.Tuple[bytes, bytes]], ContentStream
-    ]:
-        assert headers is not None
-        headers_dict = {
-            key.decode("ascii"): value.decode("ascii") for key, value in headers
-        }
-        body = JSONStream({"headers": headers_dict})
-        return b"HTTP/1.1", 200, b"OK", [], body
+from tests.utils import MockTransport
+
+
+def echo_headers(request: httpx.Request) -> httpx.Response:
+    data = {"headers": dict(request.headers)}
+    content = json.dumps(data).encode("utf-8")
+    return httpx.Response(200, content=content)
 
 
 def test_client_header():
@@ -35,7 +21,7 @@ def test_client_header():
     url = "http://example.org/echo_headers"
     headers = {"Example-Header": "example-value"}
 
-    client = httpx.Client(transport=MockTransport(), headers=headers)
+    client = httpx.Client(transport=MockTransport(echo_headers), headers=headers)
     response = client.get(url)
 
     assert response.status_code == 200
@@ -55,7 +41,7 @@ def test_header_merge():
     url = "http://example.org/echo_headers"
     client_headers = {"User-Agent": "python-myclient/0.2.1"}
     request_headers = {"X-Auth-Token": "FooBarBazToken"}
-    client = httpx.Client(transport=MockTransport(), headers=client_headers)
+    client = httpx.Client(transport=MockTransport(echo_headers), headers=client_headers)
     response = client.get(url, headers=request_headers)
 
     assert response.status_code == 200
@@ -75,7 +61,7 @@ def test_header_merge_conflicting_headers():
     url = "http://example.org/echo_headers"
     client_headers = {"X-Auth-Token": "FooBar"}
     request_headers = {"X-Auth-Token": "BazToken"}
-    client = httpx.Client(transport=MockTransport(), headers=client_headers)
+    client = httpx.Client(transport=MockTransport(echo_headers), headers=client_headers)
     response = client.get(url, headers=request_headers)
 
     assert response.status_code == 200
@@ -93,7 +79,7 @@ def test_header_merge_conflicting_headers():
 
 def test_header_update():
     url = "http://example.org/echo_headers"
-    client = httpx.Client(transport=MockTransport())
+    client = httpx.Client(transport=MockTransport(echo_headers))
     first_response = client.get(url)
     client.headers.update(
         {"User-Agent": "python-myclient/0.2.1", "Another-Header": "AThing"}
@@ -130,7 +116,7 @@ def test_remove_default_header():
     """
     url = "http://example.org/echo_headers"
 
-    client = httpx.Client(transport=MockTransport())
+    client = httpx.Client(transport=MockTransport(echo_headers))
     del client.headers["User-Agent"]
 
     response = client.get(url)
@@ -160,7 +146,7 @@ def test_host_with_auth_and_port_in_url():
     """
     url = "http://username:password@example.org:80/echo_headers"
 
-    client = httpx.Client(transport=MockTransport())
+    client = httpx.Client(transport=MockTransport(echo_headers))
     response = client.get(url)
 
     assert response.status_code == 200
@@ -183,7 +169,7 @@ def test_host_with_non_default_port_in_url():
     """
     url = "http://username:password@example.org:123/echo_headers"
 
-    client = httpx.Client(transport=MockTransport())
+    client = httpx.Client(transport=MockTransport(echo_headers))
     response = client.get(url)
 
     assert response.status_code == 200
index 22f715dadc2fa610461b45a99c9bd176ba3b1121..39731d5bb0210c8d7302f404e750b2e9d03f0a55 100644 (file)
@@ -1,24 +1,9 @@
-import typing
-
-import httpcore
-
 import httpx
-from httpx._content_streams import ContentStream, JSONStream
+from tests.utils import MockTransport
 
 
-class MockTransport(httpcore.SyncHTTPTransport):
-    def request(
-        self,
-        method: bytes,
-        url: typing.Tuple[bytes, bytes, typing.Optional[int], bytes],
-        headers: typing.List[typing.Tuple[bytes, bytes]] = None,
-        stream: httpcore.SyncByteStream = None,
-        timeout: typing.Mapping[str, typing.Optional[float]] = None,
-    ) -> typing.Tuple[
-        bytes, int, bytes, typing.List[typing.Tuple[bytes, bytes]], ContentStream
-    ]:
-        body = JSONStream({"ok": "ok"})
-        return b"HTTP/1.1", 200, b"OK", [], body
+def hello_world(request: httpx.Request) -> httpx.Response:
+    return httpx.Response(200, content=b"Hello, world")
 
 
 def test_client_queryparams():
@@ -42,7 +27,9 @@ def test_client_queryparams_echo():
     url = "http://example.org/echo_queryparams"
     client_queryparams = "first=str"
     request_queryparams = {"second": "dict"}
-    client = httpx.Client(transport=MockTransport(), params=client_queryparams)
+    client = httpx.Client(
+        transport=MockTransport(hello_world), params=client_queryparams
+    )
     response = client.get(url, params=request_queryparams)
 
     assert response.status_code == 200
index 4b00133e313168ed3220fa8be4832563ec975212..63fcd32087842e887b434faccd69de4b3f36440f 100644 (file)
 import json
-import typing
-from urllib.parse import parse_qs
 
 import httpcore
 import pytest
 
 import httpx
-from httpx._content_streams import ByteStream, ContentStream, IteratorStream
-
-
-def get_header_value(headers, key, default=None):
-    lookup = key.encode("ascii").lower()
-    for header_key, header_value in headers:
-        if header_key.lower() == lookup:
-            return header_value.decode("ascii")
-    return default
-
-
-class MockTransport:
-    def _request(
-        self,
-        method: bytes,
-        url: typing.Tuple[bytes, bytes, int, bytes],
-        headers: typing.List[typing.Tuple[bytes, bytes]],
-        stream: ContentStream,
-        timeout: typing.Mapping[str, typing.Optional[float]] = None,
-    ) -> typing.Tuple[
-        bytes, int, bytes, typing.List[typing.Tuple[bytes, bytes]], ContentStream
-    ]:
-        scheme, host, port, path = url
-        if scheme not in (b"http", b"https"):
-            raise httpcore.UnsupportedProtocol(f"Scheme {scheme!r} not supported.")
-
-        path, _, query = path.partition(b"?")
-        if path == b"/no_redirect":
-            return b"HTTP/1.1", httpx.codes.OK, b"OK", [], ByteStream(b"")
-
-        elif path == b"/redirect_301":
-
-            def body():
-                yield b"<a href='https://example.org/'>here</a>"
-
-            status_code = httpx.codes.MOVED_PERMANENTLY
-            headers = [(b"location", b"https://example.org/")]
-            stream = IteratorStream(iterator=body())
-            return b"HTTP/1.1", status_code, b"Moved Permanently", headers, stream
-
-        elif path == b"/redirect_302":
-            status_code = httpx.codes.FOUND
-            headers = [(b"location", b"https://example.org/")]
-            return b"HTTP/1.1", status_code, b"Found", headers, ByteStream(b"")
-
-        elif path == b"/redirect_303":
-            status_code = httpx.codes.SEE_OTHER
-            headers = [(b"location", b"https://example.org/")]
-            return b"HTTP/1.1", status_code, b"See Other", headers, ByteStream(b"")
-
-        elif path == b"/relative_redirect":
-            status_code = httpx.codes.SEE_OTHER
-            headers = [(b"location", b"/")]
-            return b"HTTP/1.1", status_code, b"See Other", headers, ByteStream(b"")
-
-        elif path == b"/malformed_redirect":
-            status_code = httpx.codes.SEE_OTHER
-            headers = [(b"location", b"https://:443/")]
-            return b"HTTP/1.1", status_code, b"See Other", headers, ByteStream(b"")
-
-        elif path == b"/invalid_redirect":
-            status_code = httpx.codes.SEE_OTHER
-            headers = [(b"location", "https://😇/".encode("utf-8"))]
-            return b"HTTP/1.1", status_code, b"See Other", headers, ByteStream(b"")
-
-        elif path == b"/no_scheme_redirect":
-            status_code = httpx.codes.SEE_OTHER
-            headers = [(b"location", b"//example.org/")]
-            return b"HTTP/1.1", status_code, b"See Other", headers, ByteStream(b"")
-
-        elif path == b"/multiple_redirects":
-            params = parse_qs(query.decode("ascii"))
-            count = int(params.get("count", "0")[0])
-            redirect_count = count - 1
-            code = httpx.codes.SEE_OTHER if count else httpx.codes.OK
-            phrase = b"See Other" if count else b"OK"
-            location = b"/multiple_redirects"
+from tests.utils import AsyncMockTransport, MockTransport
+
+
+def redirects(request: httpx.Request) -> httpx.Response:
+    if request.url.scheme not in ("http", "https"):
+        raise httpcore.UnsupportedProtocol(
+            f"Scheme {request.url.scheme!r} not supported."
+        )
+
+    if request.url.path == "/no_redirect":
+        return httpx.Response(200)
+
+    elif request.url.path == "/redirect_301":
+        status_code = httpx.codes.MOVED_PERMANENTLY
+        content = b"<a href='https://example.org/'>here</a>"
+        headers = {"location": "https://example.org/"}
+        return httpx.Response(status_code, headers=headers, content=content)
+
+    elif request.url.path == "/redirect_302":
+        status_code = httpx.codes.FOUND
+        headers = {"location": "https://example.org/"}
+        return httpx.Response(status_code, headers=headers)
+
+    elif request.url.path == "/redirect_303":
+        status_code = httpx.codes.SEE_OTHER
+        headers = {"location": "https://example.org/"}
+        return httpx.Response(status_code, headers=headers)
+
+    elif request.url.path == "/relative_redirect":
+        status_code = httpx.codes.SEE_OTHER
+        headers = {"location": "/"}
+        return httpx.Response(status_code, headers=headers)
+
+    elif request.url.path == "/malformed_redirect":
+        status_code = httpx.codes.SEE_OTHER
+        headers = {"location": "https://:443/"}
+        return httpx.Response(status_code, headers=headers)
+
+    elif request.url.path == "/invalid_redirect":
+        status_code = httpx.codes.SEE_OTHER
+        raw_headers = [(b"location", "https://😇/".encode("utf-8"))]
+        return httpx.Response(status_code, headers=raw_headers)
+
+    elif request.url.path == "/no_scheme_redirect":
+        status_code = httpx.codes.SEE_OTHER
+        headers = {"location": "//example.org/"}
+        return httpx.Response(status_code, headers=headers)
+
+    elif request.url.path == "/multiple_redirects":
+        params = httpx.QueryParams(request.url.query)
+        count = int(params.get("count", "0"))
+        redirect_count = count - 1
+        status_code = httpx.codes.SEE_OTHER if count else httpx.codes.OK
+        if count:
+            location = "/multiple_redirects"
             if redirect_count:
-                location += b"?count=" + str(redirect_count).encode("ascii")
-            headers = [(b"location", location)] if count else []
-            return b"HTTP/1.1", code, phrase, headers, ByteStream(b"")
-
-        if path == b"/redirect_loop":
-            code = httpx.codes.SEE_OTHER
-            headers = [(b"location", b"/redirect_loop")]
-            return b"HTTP/1.1", code, b"See Other", headers, ByteStream(b"")
-
-        elif path == b"/cross_domain":
-            code = httpx.codes.SEE_OTHER
-            headers = [(b"location", b"https://example.org/cross_domain_target")]
-            return b"HTTP/1.1", code, b"See Other", headers, ByteStream(b"")
-
-        elif path == b"/cross_domain_target":
-            headers_dict = {
-                key.decode("ascii"): value.decode("ascii") for key, value in headers
-            }
-            stream = ByteStream(json.dumps({"headers": headers_dict}).encode())
-            return b"HTTP/1.1", 200, b"OK", [], stream
-
-        elif path == b"/redirect_body":
-            code = httpx.codes.PERMANENT_REDIRECT
-            headers = [(b"location", b"/redirect_body_target")]
-            return b"HTTP/1.1", code, b"Permanent Redirect", headers, ByteStream(b"")
-
-        elif path == b"/redirect_no_body":
-            code = httpx.codes.SEE_OTHER
-            headers = [(b"location", b"/redirect_body_target")]
-            return b"HTTP/1.1", code, b"See Other", headers, ByteStream(b"")
-
-        elif path == b"/redirect_body_target":
-            content = b"".join(stream)
-            headers_dict = {
-                key.decode("ascii"): value.decode("ascii") for key, value in headers
-            }
-            stream = ByteStream(
-                json.dumps({"body": content.decode(), "headers": headers_dict}).encode()
-            )
-            return b"HTTP/1.1", 200, b"OK", [], stream
-
-        elif path == b"/cross_subdomain":
-            host = get_header_value(headers, "host")
-            if host != "www.example.org":
-                headers = [(b"location", b"https://www.example.org/cross_subdomain")]
-                return (
-                    b"HTTP/1.1",
-                    httpx.codes.PERMANENT_REDIRECT,
-                    b"Permanent Redirect",
-                    headers,
-                    ByteStream(b""),
-                )
-            else:
-                return b"HTTP/1.1", 200, b"OK", [], ByteStream(b"Hello, world!")
-
-        elif path == b"/redirect_custom_scheme":
-            status_code = httpx.codes.MOVED_PERMANENTLY
-            headers = [(b"location", b"market://details?id=42")]
-            return (
-                b"HTTP/1.1",
-                status_code,
-                b"Moved Permanently",
-                headers,
-                ByteStream(b""),
-            )
-
-        stream = ByteStream(b"Hello, world!") if method != b"HEAD" else ByteStream(b"")
-
-        return b"HTTP/1.1", 200, b"OK", [], stream
-
-
-class AsyncMockTransport(MockTransport, httpcore.AsyncHTTPTransport):
-    async def request(
-        self, *args, **kwargs
-    ) -> typing.Tuple[
-        bytes, int, bytes, typing.List[typing.Tuple[bytes, bytes]], ContentStream
-    ]:
-        return self._request(*args, **kwargs)
-
-
-class SyncMockTransport(MockTransport, httpcore.SyncHTTPTransport):
-    def request(
-        self, *args, **kwargs
-    ) -> typing.Tuple[
-        bytes, int, bytes, typing.List[typing.Tuple[bytes, bytes]], ContentStream
-    ]:
-        return self._request(*args, **kwargs)
+                location += f"?count={redirect_count}"
+            headers = {"location": location}
+        else:
+            headers = {}
+        return httpx.Response(status_code, headers=headers)
+
+    if request.url.path == "/redirect_loop":
+        status_code = httpx.codes.SEE_OTHER
+        headers = {"location": "/redirect_loop"}
+        return httpx.Response(status_code, headers=headers)
+
+    elif request.url.path == "/cross_domain":
+        status_code = httpx.codes.SEE_OTHER
+        headers = {"location": "https://example.org/cross_domain_target"}
+        return httpx.Response(status_code, headers=headers)
+
+    elif request.url.path == "/cross_domain_target":
+        status_code = httpx.codes.OK
+        content = json.dumps({"headers": dict(request.headers)}).encode("utf-8")
+        return httpx.Response(status_code, content=content)
+
+    elif request.url.path == "/redirect_body":
+        status_code = httpx.codes.PERMANENT_REDIRECT
+        headers = {"location": "/redirect_body_target"}
+        return httpx.Response(status_code, headers=headers)
+
+    elif request.url.path == "/redirect_no_body":
+        status_code = httpx.codes.SEE_OTHER
+        headers = {"location": "/redirect_body_target"}
+        return httpx.Response(status_code, headers=headers)
+
+    elif request.url.path == "/redirect_body_target":
+        content = json.dumps(
+            {"body": request.content.decode("ascii"), "headers": dict(request.headers)}
+        ).encode("utf-8")
+        return httpx.Response(200, content=content)
+
+    elif request.url.path == "/cross_subdomain":
+        if request.headers["Host"] != "www.example.org":
+            status_code = httpx.codes.PERMANENT_REDIRECT
+            headers = {"location": "https://www.example.org/cross_subdomain"}
+            return httpx.Response(status_code, headers=headers)
+        else:
+            return httpx.Response(200, content=b"Hello, world!")
+
+    elif request.url.path == "/redirect_custom_scheme":
+        status_code = httpx.codes.MOVED_PERMANENTLY
+        headers = {"location": "market://details?id=42"}
+        return httpx.Response(status_code, headers=headers)
+
+    if request.method == "HEAD":
+        return httpx.Response(200)
+
+    return httpx.Response(200, content=b"Hello, world!")
 
 
 def test_no_redirect():
-    client = httpx.Client(transport=SyncMockTransport())
+    client = httpx.Client(transport=MockTransport(redirects))
     url = "https://example.com/no_redirect"
     response = client.get(url)
     assert response.status_code == 200
@@ -183,7 +126,7 @@ def test_no_redirect():
 
 
 def test_redirect_301():
-    client = httpx.Client(transport=SyncMockTransport())
+    client = httpx.Client(transport=MockTransport(redirects))
     response = client.post("https://example.org/redirect_301")
     assert response.status_code == httpx.codes.OK
     assert response.url == "https://example.org/"
@@ -191,7 +134,7 @@ def test_redirect_301():
 
 
 def test_redirect_302():
-    client = httpx.Client(transport=SyncMockTransport())
+    client = httpx.Client(transport=MockTransport(redirects))
     response = client.post("https://example.org/redirect_302")
     assert response.status_code == httpx.codes.OK
     assert response.url == "https://example.org/"
@@ -199,7 +142,7 @@ def test_redirect_302():
 
 
 def test_redirect_303():
-    client = httpx.Client(transport=SyncMockTransport())
+    client = httpx.Client(transport=MockTransport(redirects))
     response = client.get("https://example.org/redirect_303")
     assert response.status_code == httpx.codes.OK
     assert response.url == "https://example.org/"
@@ -207,7 +150,7 @@ def test_redirect_303():
 
 
 def test_disallow_redirects():
-    client = httpx.Client(transport=SyncMockTransport())
+    client = httpx.Client(transport=MockTransport(redirects))
     response = client.post("https://example.org/redirect_303", allow_redirects=False)
     assert response.status_code == httpx.codes.SEE_OTHER
     assert response.url == "https://example.org/redirect_303"
@@ -225,7 +168,7 @@ def test_head_redirect():
     """
     Contrary to Requests, redirects remain enabled by default for HEAD requests.
     """
-    client = httpx.Client(transport=SyncMockTransport())
+    client = httpx.Client(transport=MockTransport(redirects))
     response = client.head("https://example.org/redirect_302")
     assert response.status_code == httpx.codes.OK
     assert response.url == "https://example.org/"
@@ -235,7 +178,7 @@ def test_head_redirect():
 
 
 def test_relative_redirect():
-    client = httpx.Client(transport=SyncMockTransport())
+    client = httpx.Client(transport=MockTransport(redirects))
     response = client.get("https://example.org/relative_redirect")
     assert response.status_code == httpx.codes.OK
     assert response.url == "https://example.org/"
@@ -244,7 +187,7 @@ def test_relative_redirect():
 
 def test_malformed_redirect():
     # https://github.com/encode/httpx/issues/771
-    client = httpx.Client(transport=SyncMockTransport())
+    client = httpx.Client(transport=MockTransport(redirects))
     response = client.get("http://example.org/malformed_redirect")
     assert response.status_code == httpx.codes.OK
     assert response.url == "https://example.org:443/"
@@ -252,13 +195,13 @@ def test_malformed_redirect():
 
 
 def test_invalid_redirect():
-    client = httpx.Client(transport=SyncMockTransport())
+    client = httpx.Client(transport=MockTransport(redirects))
     with pytest.raises(httpx.RemoteProtocolError):
         client.get("http://example.org/invalid_redirect")
 
 
 def test_no_scheme_redirect():
-    client = httpx.Client(transport=SyncMockTransport())
+    client = httpx.Client(transport=MockTransport(redirects))
     response = client.get("https://example.org/no_scheme_redirect")
     assert response.status_code == httpx.codes.OK
     assert response.url == "https://example.org/"
@@ -266,7 +209,7 @@ def test_no_scheme_redirect():
 
 
 def test_fragment_redirect():
-    client = httpx.Client(transport=SyncMockTransport())
+    client = httpx.Client(transport=MockTransport(redirects))
     response = client.get("https://example.org/relative_redirect#fragment")
     assert response.status_code == httpx.codes.OK
     assert response.url == "https://example.org/#fragment"
@@ -274,7 +217,7 @@ def test_fragment_redirect():
 
 
 def test_multiple_redirects():
-    client = httpx.Client(transport=SyncMockTransport())
+    client = httpx.Client(transport=MockTransport(redirects))
     response = client.get("https://example.org/multiple_redirects?count=20")
     assert response.status_code == httpx.codes.OK
     assert response.url == "https://example.org/multiple_redirects"
@@ -287,14 +230,14 @@ def test_multiple_redirects():
 
 @pytest.mark.usefixtures("async_environment")
 async def test_async_too_many_redirects():
-    async with httpx.AsyncClient(transport=AsyncMockTransport()) as client:
+    async with httpx.AsyncClient(transport=AsyncMockTransport(redirects)) as client:
         with pytest.raises(httpx.TooManyRedirects):
             await client.get("https://example.org/multiple_redirects?count=21")
 
 
 @pytest.mark.usefixtures("async_environment")
 async def test_async_too_many_redirects_calling_next():
-    async with httpx.AsyncClient(transport=AsyncMockTransport()) as client:
+    async with httpx.AsyncClient(transport=AsyncMockTransport(redirects)) as client:
         url = "https://example.org/multiple_redirects?count=21"
         response = await client.get(url, allow_redirects=False)
         with pytest.raises(httpx.TooManyRedirects):
@@ -303,13 +246,13 @@ async def test_async_too_many_redirects_calling_next():
 
 
 def test_sync_too_many_redirects():
-    client = httpx.Client(transport=SyncMockTransport())
+    client = httpx.Client(transport=MockTransport(redirects))
     with pytest.raises(httpx.TooManyRedirects):
         client.get("https://example.org/multiple_redirects?count=21")
 
 
 def test_sync_too_many_redirects_calling_next():
-    client = httpx.Client(transport=SyncMockTransport())
+    client = httpx.Client(transport=MockTransport(redirects))
     url = "https://example.org/multiple_redirects?count=21"
     response = client.get(url, allow_redirects=False)
     with pytest.raises(httpx.TooManyRedirects):
@@ -318,13 +261,13 @@ def test_sync_too_many_redirects_calling_next():
 
 
 def test_redirect_loop():
-    client = httpx.Client(transport=SyncMockTransport())
+    client = httpx.Client(transport=MockTransport(redirects))
     with pytest.raises(httpx.TooManyRedirects):
         client.get("https://example.org/redirect_loop")
 
 
 def test_cross_domain_redirect_with_auth_header():
-    client = httpx.Client(transport=SyncMockTransport())
+    client = httpx.Client(transport=MockTransport(redirects))
     url = "https://example.com/cross_domain"
     headers = {"Authorization": "abc"}
     response = client.get(url, headers=headers)
@@ -333,7 +276,7 @@ def test_cross_domain_redirect_with_auth_header():
 
 
 def test_cross_domain_redirect_with_auth():
-    client = httpx.Client(transport=SyncMockTransport())
+    client = httpx.Client(transport=MockTransport(redirects))
     url = "https://example.com/cross_domain"
     response = client.get(url, auth=("user", "pass"))
     assert response.url == "https://example.org/cross_domain_target"
@@ -341,7 +284,7 @@ def test_cross_domain_redirect_with_auth():
 
 
 def test_same_domain_redirect():
-    client = httpx.Client(transport=SyncMockTransport())
+    client = httpx.Client(transport=MockTransport(redirects))
     url = "https://example.org/cross_domain"
     headers = {"Authorization": "abc"}
     response = client.get(url, headers=headers)
@@ -353,7 +296,7 @@ def test_body_redirect():
     """
     A 308 redirect should preserve the request body.
     """
-    client = httpx.Client(transport=SyncMockTransport())
+    client = httpx.Client(transport=MockTransport(redirects))
     url = "https://example.org/redirect_body"
     data = b"Example request body"
     response = client.post(url, data=data)
@@ -366,7 +309,7 @@ def test_no_body_redirect():
     """
     A 303 redirect should remove the request body.
     """
-    client = httpx.Client(transport=SyncMockTransport())
+    client = httpx.Client(transport=MockTransport(redirects))
     url = "https://example.org/redirect_no_body"
     data = b"Example request body"
     response = client.post(url, data=data)
@@ -376,7 +319,7 @@ def test_no_body_redirect():
 
 
 def test_can_stream_if_no_redirect():
-    client = httpx.Client(transport=SyncMockTransport())
+    client = httpx.Client(transport=MockTransport(redirects))
     url = "https://example.org/redirect_301"
     with client.stream("GET", url, allow_redirects=False) as response:
         assert not response.is_closed
@@ -385,7 +328,7 @@ def test_can_stream_if_no_redirect():
 
 
 def test_cannot_redirect_streaming_body():
-    client = httpx.Client(transport=SyncMockTransport())
+    client = httpx.Client(transport=MockTransport(redirects))
     url = "https://example.org/redirect_body"
 
     def streaming_body():
@@ -396,64 +339,47 @@ def test_cannot_redirect_streaming_body():
 
 
 def test_cross_subdomain_redirect():
-    client = httpx.Client(transport=SyncMockTransport())
+    client = httpx.Client(transport=MockTransport(redirects))
     url = "https://example.com/cross_subdomain"
     response = client.get(url)
     assert response.url == "https://www.example.org/cross_subdomain"
 
 
-class MockCookieTransport(httpcore.SyncHTTPTransport):
-    def request(
-        self,
-        method: bytes,
-        url: typing.Tuple[bytes, bytes, typing.Optional[int], bytes],
-        headers: typing.List[typing.Tuple[bytes, bytes]] = None,
-        stream: httpcore.SyncByteStream = None,
-        timeout: typing.Mapping[str, typing.Optional[float]] = None,
-    ) -> typing.Tuple[
-        bytes, int, bytes, typing.List[typing.Tuple[bytes, bytes]], ContentStream
-    ]:
-        scheme, host, port, path = url
-        if path == b"/":
-            cookie = get_header_value(headers, "Cookie")
-            if cookie is not None:
-                content = b"Logged in"
-            else:
-                content = b"Not logged in"
-            return b"HTTP/1.1", 200, b"OK", [], ByteStream(content)
-
-        elif path == b"/login":
-            status_code = httpx.codes.SEE_OTHER
-            headers = [
-                (b"location", b"/"),
-                (
-                    b"set-cookie",
-                    (
-                        b"session=eyJ1c2VybmFtZSI6ICJ0b21; path=/; Max-Age=1209600; "
-                        b"httponly; samesite=lax"
-                    ),
-                ),
-            ]
-            return b"HTTP/1.1", status_code, b"See Other", headers, ByteStream(b"")
-
+def cookie_sessions(request: httpx.Request) -> httpx.Response:
+    if request.url.path == "/":
+        cookie = request.headers.get("Cookie")
+        if cookie is not None:
+            content = b"Logged in"
         else:
-            assert path == b"/logout"
-            status_code = httpx.codes.SEE_OTHER
-            headers = [
-                (b"location", b"/"),
-                (
-                    b"set-cookie",
-                    (
-                        b"session=null; path=/; expires=Thu, 01 Jan 1970 00:00:00 GMT; "
-                        b"httponly; samesite=lax"
-                    ),
-                ),
-            ]
-            return b"HTTP/1.1", status_code, b"See Other", headers, ByteStream(b"")
+            content = b"Not logged in"
+        return httpx.Response(200, content=content)
+
+    elif request.url.path == "/login":
+        status_code = httpx.codes.SEE_OTHER
+        headers = {
+            "location": "/",
+            "set-cookie": (
+                "session=eyJ1c2VybmFtZSI6ICJ0b21; path=/; Max-Age=1209600; "
+                "httponly; samesite=lax"
+            ),
+        }
+        return httpx.Response(status_code, headers=headers)
+
+    else:
+        assert request.url.path == "/logout"
+        status_code = httpx.codes.SEE_OTHER
+        headers = {
+            "location": "/",
+            "set-cookie": (
+                "session=null; path=/; expires=Thu, 01 Jan 1970 00:00:00 GMT; "
+                "httponly; samesite=lax"
+            ),
+        }
+        return httpx.Response(status_code, headers=headers)
 
 
 def test_redirect_cookie_behavior():
-    client = httpx.Client(transport=MockCookieTransport())
+    client = httpx.Client(transport=MockTransport(cookie_sessions))
 
     # The client is not logged in.
     response = client.get("https://example.com/")
@@ -482,7 +408,7 @@ def test_redirect_cookie_behavior():
 
 
 def test_redirect_custom_scheme():
-    client = httpx.Client(transport=SyncMockTransport())
+    client = httpx.Client(transport=MockTransport(redirects))
     with pytest.raises(httpx.UnsupportedProtocol) as e:
         client.post("https://example.org/redirect_custom_scheme")
-    assert str(e.value) == "Scheme b'market' not supported."
+    assert str(e.value) == "Scheme 'market' not supported."
index f4962daba0b7fb18b7013a7f7e697da8562127c9..d10c39038de3fd6d2e8dcb71dcf82631ae3ea0d1 100644 (file)
@@ -4,37 +4,21 @@ import os
 import typing
 from unittest import mock
 
-import httpcore
 import pytest
 
 import httpx
 from httpx._content_streams import MultipartStream, encode
 from httpx._utils import format_form_param
+from tests.utils import MockTransport
 
 
-class MockTransport(httpcore.SyncHTTPTransport):
-    def request(
-        self,
-        method: bytes,
-        url: typing.Tuple[bytes, bytes, typing.Optional[int], bytes],
-        headers: typing.List[typing.Tuple[bytes, bytes]] = None,
-        stream: httpcore.SyncByteStream = None,
-        timeout: typing.Mapping[str, typing.Optional[float]] = None,
-    ) -> typing.Tuple[
-        bytes,
-        int,
-        bytes,
-        typing.List[typing.Tuple[bytes, bytes]],
-        httpcore.SyncByteStream,
-    ]:
-        assert stream is not None
-        content = httpcore.IteratorByteStream(iterator=(part for part in stream))
-        return b"HTTP/1.1", 200, b"OK", [], content
+def echo_request_content(request: httpx.Request) -> httpx.Response:
+    return httpx.Response(200, content=request.content)
 
 
 @pytest.mark.parametrize(("value,output"), (("abc", b"abc"), (b"abc", b"abc")))
 def test_multipart(value, output):
-    client = httpx.Client(transport=MockTransport())
+    client = httpx.Client(transport=MockTransport(echo_request_content))
 
     # Test with a single-value 'data' argument, and a plain file 'files' argument.
     data = {"text": value}
@@ -60,7 +44,7 @@ def test_multipart(value, output):
 
 @pytest.mark.parametrize(("key"), (b"abc", 1, 2.3, None))
 def test_multipart_invalid_key(key):
-    client = httpx.Client(transport=MockTransport())
+    client = httpx.Client(transport=MockTransport(echo_request_content))
 
     data = {key: "abc"}
     files = {"file": io.BytesIO(b"<file content>")}
@@ -75,7 +59,7 @@ def test_multipart_invalid_key(key):
 
 @pytest.mark.parametrize(("value"), (1, 2.3, None, [None, "abc"], {None: "abc"}))
 def test_multipart_invalid_value(value):
-    client = httpx.Client(transport=MockTransport())
+    client = httpx.Client(transport=MockTransport(echo_request_content))
 
     data = {"text": value}
     files = {"file": io.BytesIO(b"<file content>")}
@@ -85,7 +69,7 @@ def test_multipart_invalid_value(value):
 
 
 def test_multipart_file_tuple():
-    client = httpx.Client(transport=MockTransport())
+    client = httpx.Client(transport=MockTransport(echo_request_content))
 
     # Test with a list of values 'data' argument,
     #     and a tuple style 'files' argument.
index e2636a535cfb5ac2dcc96fb3cfd974ae7ada134a..ee319e0010183d4eb7e93212b079a2f863a06454 100644 (file)
@@ -1,7 +1,11 @@
 import contextlib
 import logging
 import os
+from typing import Callable, List, Mapping, Optional, Tuple
 
+import httpcore
+
+import httpx
 from httpx import _utils
 
 
@@ -18,3 +22,90 @@ def override_log_level(log_level: str):
     finally:
         # Reset the logger so we don't have verbose output in all unit tests
         logging.getLogger("httpx").handlers = []
+
+
+class MockTransport(httpcore.SyncHTTPTransport):
+    def __init__(self, handler: Callable) -> None:
+        self.handler = handler
+
+    def request(
+        self,
+        method: bytes,
+        url: Tuple[bytes, bytes, Optional[int], bytes],
+        headers: List[Tuple[bytes, bytes]] = None,
+        stream: httpcore.SyncByteStream = None,
+        timeout: Mapping[str, Optional[float]] = None,
+    ) -> Tuple[bytes, int, bytes, List[Tuple[bytes, bytes]], httpcore.SyncByteStream]:
+        raw_scheme, raw_host, port, raw_path = url
+        scheme = raw_scheme.decode("ascii")
+        host = raw_host.decode("ascii")
+        port_str = "" if port is None else f":{port}"
+        path = raw_path.decode("ascii")
+
+        request_headers = httpx.Headers(headers)
+        data = (
+            (item for item in stream)
+            if stream
+            and (
+                "Content-Length" in request_headers
+                or "Transfer-Encoding" in request_headers
+            )
+            else None
+        )
+
+        request = httpx.Request(
+            method=method.decode("ascii"),
+            url=f"{scheme}://{host}{port_str}{path}",
+            headers=request_headers,
+            data=data,
+        )
+        request.read()
+        response = self.handler(request)
+        return (
+            response.http_version.encode("ascii")
+            if response.http_version
+            else b"HTTP/1.1",
+            response.status_code,
+            response.reason_phrase.encode("ascii"),
+            response.headers.raw,
+            response._raw_stream,
+        )
+
+
+class AsyncMockTransport(httpcore.AsyncHTTPTransport):
+    def __init__(self, handler: Callable) -> None:
+        self.impl = MockTransport(handler)
+
+    async def request(
+        self,
+        method: bytes,
+        url: Tuple[bytes, bytes, Optional[int], bytes],
+        headers: List[Tuple[bytes, bytes]] = None,
+        stream: httpcore.AsyncByteStream = None,
+        timeout: Mapping[str, Optional[float]] = None,
+    ) -> Tuple[bytes, int, bytes, List[Tuple[bytes, bytes]], httpcore.AsyncByteStream]:
+        content = (
+            httpcore.PlainByteStream(b"".join([part async for part in stream]))
+            if stream
+            else httpcore.PlainByteStream(b"")
+        )
+
+        (
+            http_version,
+            status_code,
+            reason_phrase,
+            headers,
+            response_stream,
+        ) = self.impl.request(
+            method, url, headers=headers, stream=content, timeout=timeout
+        )
+
+        content = httpcore.PlainByteStream(b"".join([part for part in response_stream]))
+
+        return (
+            http_version,
+            status_code,
+            reason_phrase,
+            headers,
+            content,
+        )