]> git.ipfire.org Git - thirdparty/starlette.git/commitdiff
fix handling of scope's client in Request (#1462)
authorAdrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com>
Tue, 1 Feb 2022 11:55:35 +0000 (05:55 -0600)
committerGitHub <noreply@github.com>
Tue, 1 Feb 2022 11:55:35 +0000 (05:55 -0600)
starlette/datastructures.py
starlette/requests.py
tests/test_requests.py
tests/test_testclient.py

index ae2d9803a3510a330b79f6d8a4a2be9db08068c3..1a8b965e8a2c169ae1bdec4246cdbe7de44ec167 100644 (file)
@@ -9,8 +9,8 @@ from starlette.types import Scope
 
 
 class Address(typing.NamedTuple):
-    host: typing.Optional[str]
-    port: typing.Optional[int]
+    host: str
+    port: int
 
 
 class URL:
index bdfcfcbc126cd081d23aef3f9f03298f6e5063c0..e3c91e284150e8cd4ae1b0597b55de832d02c55e 100644 (file)
@@ -134,9 +134,12 @@ class HTTPConnection(Mapping):
         return self._cookies
 
     @property
-    def client(self) -> Address:
-        host, port = self.scope.get("client") or (None, None)
-        return Address(host=host, port=port)
+    def client(self) -> typing.Optional[Address]:
+        # client is a 2 item tuple of (host, port), None or missing
+        host_port = self.scope.get("client")
+        if host_port is not None:
+            return Address(*host_port)
+        return None
 
     @property
     def session(self) -> dict:
index e535e779031f8cf0de1d3529e3f569db6f251fb6..799e61f805d026cf2755f2e8061fda3b013e0391 100644 (file)
@@ -1,8 +1,12 @@
+from typing import Optional
+
 import anyio
 import pytest
 
+from starlette.datastructures import Address
 from starlette.requests import ClientDisconnect, Request, State
 from starlette.responses import JSONResponse, PlainTextResponse, Response
+from starlette.types import Scope
 
 
 def test_request_url(test_client_factory):
@@ -52,17 +56,18 @@ def test_request_headers(test_client_factory):
     }
 
 
-def test_request_client(test_client_factory):
-    async def app(scope, receive, send):
-        request = Request(scope, receive)
-        response = JSONResponse(
-            {"host": request.client.host, "port": request.client.port}
-        )
-        await response(scope, receive, send)
-
-    client = test_client_factory(app)
-    response = client.get("/")
-    assert response.json() == {"host": "testclient", "port": 50000}
+@pytest.mark.parametrize(
+    "scope,expected_client",
+    [
+        ({"client": ["client", 42]}, Address("client", 42)),
+        ({"client": None}, None),
+        ({}, None),
+    ],
+)
+def test_request_client(scope: Scope, expected_client: Optional[Address]):
+    scope.update({"type": "http"})  # required by Request's constructor
+    client = Request(scope).client
+    assert client == expected_client
 
 
 def test_request_body(test_client_factory):
index 8c066678966c7274aa25839a8086d5e9d3f62a5c..eb47a59c2b14269f352323482cf086d98442f1a2 100644 (file)
@@ -229,3 +229,16 @@ def test_websocket_blocking_receive(test_client_factory):
     with client.websocket_connect("/") as websocket:
         data = websocket.receive_json()
         assert data == {"message": "test"}
+
+
+def test_client(test_client_factory):
+    async def app(scope, receive, send):
+        client = scope.get("client")
+        assert client is not None
+        host, port = client
+        response = JSONResponse({"host": host, "port": port})
+        await response(scope, receive, send)
+
+    client = test_client_factory(app)
+    response = client.get("/")
+    assert response.json() == {"host": "testclient", "port": 50000}