]> git.ipfire.org Git - thirdparty/starlette.git/commitdiff
request.scope and request.state 557/head
authorTom Christie <tom@tomchristie.com>
Wed, 19 Jun 2019 09:08:19 +0000 (10:08 +0100)
committerTom Christie <tom@tomchristie.com>
Wed, 19 Jun 2019 09:08:19 +0000 (10:08 +0100)
starlette/middleware/base.py
starlette/requests.py
tests/middleware/test_base.py
tests/test_requests.py
tests/test_responses.py
tests/test_websockets.py

index ea5afb210de3dfb8279c1e9b3e7d38915df06e99..67d35cc865cd412074fdda8e6ea8d46659541d44 100644 (file)
@@ -29,7 +29,7 @@ class BaseHTTPMiddleware:
         loop = asyncio.get_event_loop()
         queue = asyncio.Queue()  # type: asyncio.Queue
 
-        scope = dict(request)
+        scope = request.scope
         receive = request.receive
         send = queue.put
 
index c17a60809b2923dba97b7b880535da1489c619b5..be05f291aaf9a0e64430055b1a10c9b628d313dd 100644 (file)
@@ -26,12 +26,11 @@ class State(object):
         self._state[key] = value
 
     def __getattr__(self, key: typing.Any) -> typing.Any:
-        if key in self._state:
+        try:
             return self._state[key]
-        else:
-            raise AttributeError(
-                "'{}' object has no attribute '{}'".format(self.__class__.__name__, key)
-            )
+        except KeyError:
+            message = "'{}' object has no attribute '{}'"
+            raise AttributeError(message.format(self.__class__.__name__, key))
 
     def __delattr__(self, key: typing.Any) -> None:
         del self._state[key]
@@ -45,45 +44,42 @@ class HTTPConnection(Mapping):
 
     def __init__(self, scope: Scope, receive: Receive = None) -> None:
         assert scope["type"] in ("http", "websocket")
-        self._scope = scope
-
-        # Ensure 'state' has an empty dict if it's not already populated.
-        self._scope.setdefault("state", {})
+        self.scope = scope
 
     def __getitem__(self, key: str) -> str:
-        return self._scope[key]
+        return self.scope[key]
 
     def __iter__(self) -> typing.Iterator[str]:
-        return iter(self._scope)
+        return iter(self.scope)
 
     def __len__(self) -> int:
-        return len(self._scope)
+        return len(self.scope)
 
     @property
     def app(self) -> typing.Any:
-        return self._scope["app"]
+        return self.scope["app"]
 
     @property
     def url(self) -> URL:
         if not hasattr(self, "_url"):
-            self._url = URL(scope=self._scope)
+            self._url = URL(scope=self.scope)
         return self._url
 
     @property
     def headers(self) -> Headers:
         if not hasattr(self, "_headers"):
-            self._headers = Headers(scope=self._scope)
+            self._headers = Headers(scope=self.scope)
         return self._headers
 
     @property
     def query_params(self) -> QueryParams:
         if not hasattr(self, "_query_params"):
-            self._query_params = QueryParams(self._scope["query_string"])
+            self._query_params = QueryParams(self.scope["query_string"])
         return self._query_params
 
     @property
     def path_params(self) -> dict:
-        return self._scope.get("path_params", {})
+        return self.scope.get("path_params", {})
 
     @property
     def cookies(self) -> typing.Dict[str, str]:
@@ -100,39 +96,41 @@ class HTTPConnection(Mapping):
 
     @property
     def client(self) -> Address:
-        host, port = self._scope.get("client") or (None, None)
+        host, port = self.scope.get("client") or (None, None)
         return Address(host=host, port=port)
 
     @property
     def session(self) -> dict:
         assert (
-            "session" in self._scope
+            "session" in self.scope
         ), "SessionMiddleware must be installed to access request.session"
-        return self._scope["session"]
+        return self.scope["session"]
 
     @property
     def auth(self) -> typing.Any:
         assert (
-            "auth" in self._scope
+            "auth" in self.scope
         ), "AuthenticationMiddleware must be installed to access request.auth"
-        return self._scope["auth"]
+        return self.scope["auth"]
 
     @property
     def user(self) -> typing.Any:
         assert (
-            "user" in self._scope
+            "user" in self.scope
         ), "AuthenticationMiddleware must be installed to access request.user"
-        return self._scope["user"]
+        return self.scope["user"]
 
     @property
     def state(self) -> State:
         if not hasattr(self, "_state"):
+            # Ensure 'state' has an empty dict if it's not already populated.
+            self.scope.setdefault("state", {})
             # Create a state instance with a reference to the dict in which it should store info
-            self._state = State(self._scope["state"])
+            self._state = State(self.scope["state"])
         return self._state
 
     def url_for(self, name: str, **path_params: typing.Any) -> str:
-        router = self._scope["router"]
+        router = self.scope["router"]
         url_path = router.url_path_for(name, **path_params)
         return url_path.make_absolute_url(base_url=self.url)
 
@@ -151,7 +149,7 @@ class Request(HTTPConnection):
 
     @property
     def method(self) -> str:
-        return self._scope["method"]
+        return self.scope["method"]
 
     @property
     def receive(self) -> Receive:
index a9d7d423af537ea191396c543a35c7361f83ef33..09a1d0b37310ab83ff181b4d815b834553f07e01 100644 (file)
@@ -93,19 +93,23 @@ def test_state_data_across_multiple_middlewares():
     class aMiddleware(BaseHTTPMiddleware):
         async def dispatch(self, request, call_next):
             request.state.foo = expected_value1
+            print("a", dict(request))
             response = await call_next(request)
             return response
 
     class bMiddleware(BaseHTTPMiddleware):
         async def dispatch(self, request, call_next):
             request.state.bar = expected_value2
+            print("b", dict(request))
             response = await call_next(request)
             response.headers["X-State-Foo"] = request.state.foo
             return response
 
     class cMiddleware(BaseHTTPMiddleware):
         async def dispatch(self, request, call_next):
+            print("c", dict(request))
             response = await call_next(request)
+            print("c", dict(request))
             response.headers["X-State-Bar"] = request.state.bar
             return response
 
index 20b456f95d64bfceb130d249c83d822bd85f2610..03c74c4515f3b189c0b25dd080c2a63894e779bb 100644 (file)
@@ -174,13 +174,8 @@ def test_request_scope_interface():
     """
     request = Request({"type": "http", "method": "GET", "path": "/abc/"})
     assert request["method"] == "GET"
-    assert dict(request) == {
-        "type": "http",
-        "method": "GET",
-        "path": "/abc/",
-        "state": {},
-    }
-    assert len(request) == 4
+    assert dict(request) == {"type": "http", "method": "GET", "path": "/abc/"}
+    assert len(request) == 3
 
 
 def test_request_without_setting_receive():
@@ -249,19 +244,14 @@ def test_request_state_object():
     scope = {"state": {"old": "foo"}}
 
     s = State(scope["state"])
-    assert s._state == scope["state"]
-    assert getattr(s, "_state") == scope["state"]
 
     s.new = "value"
     assert s.new == "value"
-    assert s._state["new"] == "value"  # test if inner _state dict is updated.
 
     del s.new
 
-    try:
-        assert s.new == "value"  # will bombed with AttributeError
-    except AttributeError as e:
-        assert str(e) == "'State' object has no attribute 'new'"
+    with pytest.raises(AttributeError):
+        s.new
 
 
 def test_request_state():
index c7ca0ac0a8b80dd210486848217f363749d69ddf..f90bde48b41d0243521e045f1dfe313d15974197 100644 (file)
@@ -8,11 +8,11 @@ from starlette.background import BackgroundTask
 from starlette.requests import Request
 from starlette.responses import (
     FileResponse,
+    JSONResponse,
     RedirectResponse,
     Response,
     StreamingResponse,
     UJSONResponse,
-    JSONResponse,
 )
 from starlette.testclient import TestClient
 
index 1c0519d620a5662a5381e5f52e5143b7f7bf92fa..41e77237ca6a06c679118c51d747277622abaa77 100644 (file)
@@ -281,10 +281,5 @@ def test_websocket_scope_interface():
         send=mock_send,
     )
     assert websocket["type"] == "websocket"
-    assert dict(websocket) == {
-        "type": "websocket",
-        "path": "/abc/",
-        "headers": [],
-        "state": {},
-    }
-    assert len(websocket) == 4
+    assert dict(websocket) == {"type": "websocket", "path": "/abc/", "headers": []}
+    assert len(websocket) == 3