git.ipfire.org Git - thirdparty/paperless-ngx.git/commitdiff
Merge commit from fork
author shamoon <4887959+shamoon@users.noreply.github.com>
Fri, 12 Dec 2025 17:28:17 +0000 (09:28 -0800)
committer GitHub <noreply@github.com>
Fri, 12 Dec 2025 17:28:17 +0000 (09:28 -0800)
* Use a custom transport to eliminate the slim chance of DNS rebinding affecting the webhook

* Fix WebhookTransport hostname resolution and validation

* Fix test failures

* Lint

* Keep all internal logic inside WebhookTransport

* Fix test failure

* Update handlers.py

* Update handlers.py

---------

Co-authored-by: Trenton H <797416+stumpylog@users.noreply.github.com>
src/documents/tests/test_workflows.py
src/documents/workflows/webhooks.py

index e606bc5a072a19439e3981f7fe63c2e6e1475f47..249183b6e85fe6fb02e2e82b73b2d694820d53db 100644 (file)
@@ -17,6 +17,7 @@ from django.utils import timezone
 from guardian.shortcuts import assign_perm
 from guardian.shortcuts import get_groups_with_perms
 from guardian.shortcuts import get_users_with_perms
+from httpx import ConnectError
 from httpx import HTTPError
 from httpx import HTTPStatusError
 from pytest_httpx import HTTPXMock
@@ -3428,7 +3429,7 @@ class TestWorkflows(
             expected_str = "Error occurred parsing webhook headers"
             self.assertIn(expected_str, cm.output[1])
 
-    @mock.patch("httpx.post")
+    @mock.patch("httpx.Client.post")
     def test_workflow_webhook_send_webhook_task(self, mock_post):
         mock_post.return_value = mock.Mock(
             status_code=200,
@@ -3449,8 +3450,6 @@ class TestWorkflows(
                 content="Test message",
                 headers={},
                 files=None,
-                follow_redirects=False,
-                timeout=5,
             )
 
             expected_str = "Webhook sent to http://paperless-ngx.com"
@@ -3468,11 +3467,9 @@ class TestWorkflows(
                 data={"message": "Test message"},
                 headers={},
                 files=None,
-                follow_redirects=False,
-                timeout=5,
             )
 
-    @mock.patch("httpx.post")
+    @mock.patch("httpx.Client.post")
     def test_workflow_webhook_send_webhook_retry(self, mock_http):
         mock_http.return_value.raise_for_status = mock.Mock(
             side_effect=HTTPStatusError(
@@ -3668,7 +3665,7 @@ class TestWebhookSecurity:
             - ValueError is raised
         """
         resolve_to("127.0.0.1")
-        with pytest.raises(ValueError):
+        with pytest.raises(ConnectError):
             send_webhook(
                 "http://paperless-ngx.com",
                 data="",
@@ -3698,7 +3695,8 @@ class TestWebhookSecurity:
         )
 
         req = httpx_mock.get_request()
-        assert req.url.host == "paperless-ngx.com"
+        assert req.url.host == "52.207.186.75"
+        assert req.headers["host"] == "paperless-ngx.com"
 
     def test_follow_redirects_disabled(self, httpx_mock: HTTPXMock, resolve_to):
         """
index c7bb9f7c24ed6c9cdf6c78462be21e71ff3bbe89..49fb09f6d01ce3cf76d507a1df4416412c08b631 100644 (file)
@@ -10,26 +10,98 @@ from django.conf import settings
 logger = logging.getLogger("paperless.workflows.webhooks")
 
 
-def _is_public_ip(ip: str) -> bool:
-    try:
-        obj = ipaddress.ip_address(ip)
-        return not (
-            obj.is_private
-            or obj.is_loopback
-            or obj.is_link_local
-            or obj.is_multicast
-            or obj.is_unspecified
+class WebhookTransport(httpx.HTTPTransport):
+    """
+    Transport that resolves/validates hostnames and rewrites to a vetted IP
+    while keeping Host/SNI as the original hostname.
+    """
+
+    def __init__(
+        self,
+        hostname: str,
+        *args,
+        allow_internal: bool = False,
+        **kwargs,
+    ) -> None:
+        super().__init__(*args, **kwargs)
+        self.hostname = hostname
+        self.allow_internal = allow_internal
+
+    def handle_request(self, request: httpx.Request) -> httpx.Response:
+        hostname = request.url.host
+
+        if not hostname:
+            raise httpx.ConnectError("No hostname in request URL")
+
+        try:
+            addr_info = socket.getaddrinfo(hostname, None)
+        except socket.gaierror as e:
+            raise httpx.ConnectError(f"Could not resolve hostname: {hostname}") from e
+
+        ips = [info[4][0] for info in addr_info if info and info[4]]
+        if not ips:
+            raise httpx.ConnectError(f"Could not resolve hostname: {hostname}")
+
+        if not self.allow_internal:
+            for ip_str in ips:
+                if not WebhookTransport.is_public_ip(ip_str):
+                    raise httpx.ConnectError(
+                        f"Connection blocked: {hostname} resolves to a non-public address",
+                    )
+
+        ip_str = ips[0]
+        formatted_ip = self._format_ip_for_url(ip_str)
+
+        new_headers = httpx.Headers(request.headers)
+        if "host" in new_headers:
+            del new_headers["host"]
+        new_headers["Host"] = hostname
+        new_url = request.url.copy_with(host=formatted_ip)
+
+        request = httpx.Request(
+            method=request.method,
+            url=new_url,
+            headers=new_headers,
+            content=request.content,
+            extensions=request.extensions,
         )
-    except ValueError:  # pragma: no cover
-        return False
+        request.extensions["sni_hostname"] = hostname
 
+        return super().handle_request(request)
 
-def _resolve_first_ip(host: str) -> str | None:
-    try:
-        info = socket.getaddrinfo(host, None)
-        return info[0][4][0] if info else None
-    except Exception:  # pragma: no cover
-        return None
+    def _format_ip_for_url(self, ip: str) -> str:
+        """
+        Format IP address for use in URL (wrap IPv6 in brackets)
+        """
+        try:
+            ip_obj = ipaddress.ip_address(ip)
+            if ip_obj.version == 6:
+                return f"[{ip}]"
+            return ip
+        except ValueError:
+            return ip
+
+    @staticmethod
+    def is_public_ip(ip: str | int) -> bool:
+        try:
+            obj = ipaddress.ip_address(ip)
+            return not (
+                obj.is_private
+                or obj.is_loopback
+                or obj.is_link_local
+                or obj.is_multicast
+                or obj.is_unspecified
+            )
+        except ValueError:  # pragma: no cover
+            return False
+
+    @staticmethod
+    def resolve_first_ip(host: str) -> str | None:
+        try:
+            info = socket.getaddrinfo(host, None)
+            return info[0][4][0] if info else None
+        except Exception:  # pragma: no cover
+            return None
 
 
 @shared_task(
@@ -59,12 +131,10 @@ def send_webhook(
         logger.warning("Webhook blocked: port not permitted")
         raise ValueError("Destination port not permitted.")
 
-    ip = _resolve_first_ip(p.hostname)
-    if not ip or (
-        not _is_public_ip(ip) and not settings.WEBHOOKS_ALLOW_INTERNAL_REQUESTS
-    ):
-        logger.warning("Webhook blocked: destination not allowed")
-        raise ValueError("Destination host is not allowed.")
+    transport = WebhookTransport(
+        hostname=p.hostname,
+        allow_internal=settings.WEBHOOKS_ALLOW_INTERNAL_REQUESTS,
+    )
 
     try:
         post_args = {
@@ -73,8 +143,6 @@ def send_webhook(
                 k: v for k, v in (headers or {}).items() if k.lower() != "host"
             },
             "files": files or None,
-            "timeout": 5.0,
-            "follow_redirects": False,
         }
         if as_json:
             post_args["json"] = data
@@ -83,14 +151,21 @@ def send_webhook(
         else:
             post_args["content"] = data
 
-        httpx.post(
-            **post_args,
-        ).raise_for_status()
-        logger.info(
-            f"Webhook sent to {url}",
-        )
+        with httpx.Client(
+            transport=transport,
+            timeout=5.0,
+            follow_redirects=False,
+        ) as client:
+            client.post(
+                **post_args,
+            ).raise_for_status()
+            logger.info(
+                f"Webhook sent to {url}",
+            )
     except Exception as e:
         logger.error(
             f"Failed attempt sending webhook to {url}: {e}",
         )
         raise e
+    finally:
+        transport.close()