]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Redirections
authorTom Christie <tom@tomchristie.com>
Mon, 29 Apr 2019 11:35:58 +0000 (12:35 +0100)
committerTom Christie <tom@tomchristie.com>
Mon, 29 Apr 2019 11:35:58 +0000 (12:35 +0100)
httpcore/__init__.py
httpcore/adapters/redirects.py
httpcore/exceptions.py
httpcore/models.py
tests/adapters/test_redirects.py

index d9c1364c7f57f0672def5e1abcd47d20d5142e42..fdd992534fd338df34f0e7fab4290ee4daa898c5 100644 (file)
@@ -10,10 +10,11 @@ from .exceptions import (
     PoolTimeout,
     ProtocolError,
     ReadTimeout,
+    RedirectLoop,
     ResponseClosed,
     StreamConsumed,
     Timeout,
-    TooManyRedirects
+    TooManyRedirects,
 )
 from .interfaces import Adapter
 from .models import URL, Headers, Origin, Request, Response
index 0cf3f9fbb740eb741a7863c763b54aad4b12cad7..9ef287f720a1e0d9ece961404a4630fe69d4da8f 100644 (file)
@@ -2,7 +2,7 @@ import typing
 from urllib.parse import urljoin, urlparse
 
 from ..config import DEFAULT_MAX_REDIRECTS
-from ..exceptions import TooManyRedirects
+from ..exceptions import RedirectLoop, TooManyRedirects
 from ..interfaces import Adapter
 from ..models import URL, Request, Response
 from ..status_codes import codes
@@ -20,6 +20,7 @@ class RedirectAdapter(Adapter):
     async def send(self, request: Request, **options: typing.Any) -> Response:
         allow_redirects = options.pop("allow_redirects", True)
         history = []
+        seen_urls = set((request.url,))
 
         while True:
             response = await self.dispatch.send(request, **options)
@@ -29,6 +30,9 @@ class RedirectAdapter(Adapter):
             if len(history) > self.max_redirects:
                 raise TooManyRedirects()
             request = self.build_redirect_request(request, response)
+            if request.url in seen_urls:
+                raise RedirectLoop()
+            seen_urls.add(request.url)
 
         return response
 
index 337b74a2e4f96a8fb282ad4fe3879deef3eea186..12aac53b852560931766eacc32e0dabc6ed2c16f 100644 (file)
@@ -28,12 +28,24 @@ class PoolTimeout(Timeout):
     """
 
 
-class TooManyRedirects(Exception):
+class RedirectError(Exception):
+    """
+    Base class for HTTP redirect errors.
+    """
+
+
+class TooManyRedirects(RedirectError):
     """
     Too many redirects.
     """
 
 
+class RedirectLoop(RedirectError):
+    """
+    Infinite redirect loop.
+    """
+
+
 class ProtocolError(Exception):
     """
     Malformed HTTP.
index b08da934a346387e5e5c0b44f204e55f3339649b..b0a4723d6a1d4621820d5afa33298fb7e21d9176 100644 (file)
@@ -70,6 +70,12 @@ class URL:
     def origin(self) -> "Origin":
         return Origin(self)
 
+    def __hash__(self) -> int:
+        return hash(str(self))
+
+    def __eq__(self, other: typing.Any) -> bool:
+        return isinstance(other, URL) and str(self) == str(other)
+
     def __str__(self) -> str:
         return self.components.geturl()
 
index dbcb2338570669ed34614a9b2c814479ed3f3316..3609dd7342d972e3f6b8815df5eaedd87c3da575 100644 (file)
@@ -1,7 +1,16 @@
-import pytest
 from urllib.parse import parse_qs
 
-from httpcore import Adapter, RedirectAdapter, Request, Response, TooManyRedirects, codes
+import pytest
+
+from httpcore import (
+    Adapter,
+    RedirectAdapter,
+    RedirectLoop,
+    Request,
+    Response,
+    TooManyRedirects,
+    codes,
+)
 
 
 class MockDispatch(Adapter):
@@ -9,12 +18,21 @@ class MockDispatch(Adapter):
         pass
 
     async def send(self, request: Request, **options) -> Response:
-        if request.url.path == "/redirect_303":
-            return Response(
-                codes.see_other, headers=[(b"location", b"https://example.org/")]
-            )
+        if request.url.path == "/redirect_301":  # "Moved Permanently"
+            return Response(301, headers=[(b"location", b"https://example.org/")])
+
+        elif request.url.path == "/redirect_302":  # "Found"
+            return Response(302, headers=[(b"location", b"https://example.org/")])
+
+        elif request.url.path == "/redirect_303":  # "See Other"
+            return Response(303, headers=[(b"location", b"https://example.org/")])
+
         elif request.url.path == "/relative_redirect":
             return Response(codes.see_other, headers=[(b"location", b"/")])
+
+        elif request.url.path == "/no_scheme_redirect":
+            return Response(codes.see_other, headers=[(b"location", b"//example.org/")])
+
         elif request.url.path == "/multiple_redirects":
             params = parse_qs(request.url.query)
             count = int(params.get("count", "0")[0])
@@ -22,9 +40,27 @@ class MockDispatch(Adapter):
             location = "/multiple_redirects?count=" + str(count - 1)
             headers = [(b"location", location.encode())] if count else []
             return Response(code, headers=headers)
+
+        if request.url.path == "/redirect_loop":
+            return Response(codes.see_other, headers=[(b"location", b"/redirect_loop")])
+
         return Response(codes.ok, body=b"Hello, world!")
 
 
+@pytest.mark.asyncio
+async def test_redirect_301():
+    client = RedirectAdapter(MockDispatch())
+    response = await client.request("POST", "https://example.org/redirect_301")
+    assert response.status_code == codes.ok
+
+
+@pytest.mark.asyncio
+async def test_redirect_302():
+    client = RedirectAdapter(MockDispatch())
+    response = await client.request("POST", "https://example.org/redirect_302")
+    assert response.status_code == codes.ok
+
+
 @pytest.mark.asyncio
 async def test_redirect_303():
     client = RedirectAdapter(MockDispatch())
@@ -39,10 +75,26 @@ async def test_relative_redirect():
     assert response.status_code == codes.ok
 
 
+@pytest.mark.asyncio
+async def test_no_scheme_redirect():
+    client = RedirectAdapter(MockDispatch())
+    response = await client.request("GET", "https://example.org/no_scheme_redirect")
+    assert response.status_code == codes.ok
+
+
+@pytest.mark.asyncio
+async def test_fragment_redirect():
+    client = RedirectAdapter(MockDispatch())
+    response = await client.request("GET", "https://example.org/relative_redirect#fragment")
+    assert response.status_code == codes.ok
+
+
 @pytest.mark.asyncio
 async def test_multiple_redirects():
     client = RedirectAdapter(MockDispatch())
-    response = await client.request("GET", "https://example.org/multiple_redirects?count=20")
+    response = await client.request(
+        "GET", "https://example.org/multiple_redirects?count=20"
+    )
     assert response.status_code == codes.ok
 
 
@@ -51,3 +103,10 @@ async def test_too_many_redirects():
     client = RedirectAdapter(MockDispatch())
     with pytest.raises(TooManyRedirects):
         await client.request("GET", "https://example.org/multiple_redirects?count=21")
+
+
+@pytest.mark.asyncio
+async def test_redirect_loop():
+    client = RedirectAdapter(MockDispatch())
+    with pytest.raises(RedirectLoop):
+        await client.request("GET", "https://example.org/redirect_loop")