From: Tom Christie Date: Wed, 19 Jun 2019 09:08:19 +0000 (+0100) Subject: request.scope and request.state X-Git-Tag: 0.12.2~9^2 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=refs%2Fpull%2F557%2Fhead;p=thirdparty%2Fstarlette.git request.scope and request.state --- diff --git a/starlette/middleware/base.py b/starlette/middleware/base.py index ea5afb21..67d35cc8 100644 --- a/starlette/middleware/base.py +++ b/starlette/middleware/base.py @@ -29,7 +29,7 @@ class BaseHTTPMiddleware: loop = asyncio.get_event_loop() queue = asyncio.Queue() # type: asyncio.Queue - scope = dict(request) + scope = request.scope receive = request.receive send = queue.put diff --git a/starlette/requests.py b/starlette/requests.py index c17a6080..be05f291 100644 --- a/starlette/requests.py +++ b/starlette/requests.py @@ -26,12 +26,11 @@ class State(object): self._state[key] = value def __getattr__(self, key: typing.Any) -> typing.Any: - if key in self._state: + try: return self._state[key] - else: - raise AttributeError( - "'{}' object has no attribute '{}'".format(self.__class__.__name__, key) - ) + except KeyError: + message = "'{}' object has no attribute '{}'" + raise AttributeError(message.format(self.__class__.__name__, key)) def __delattr__(self, key: typing.Any) -> None: del self._state[key] @@ -45,45 +44,42 @@ class HTTPConnection(Mapping): def __init__(self, scope: Scope, receive: Receive = None) -> None: assert scope["type"] in ("http", "websocket") - self._scope = scope - - # Ensure 'state' has an empty dict if it's not already populated. - self._scope.setdefault("state", {}) + self.scope = scope def __getitem__(self, key: str) -> str: - return self._scope[key] + return self.scope[key] def __iter__(self) -> typing.Iterator[str]: - return iter(self._scope) + return iter(self.scope) def __len__(self) -> int: - return len(self._scope) + return len(self.scope) @property def app(self) -> typing.Any: - return self._scope["app"] + return self.scope["app"] @property def url(self) -> URL: if not hasattr(self, "_url"): - self._url = URL(scope=self._scope) + self._url = URL(scope=self.scope) return self._url @property def headers(self) -> Headers: if not hasattr(self, "_headers"): - self._headers = Headers(scope=self._scope) + self._headers = Headers(scope=self.scope) return self._headers @property def query_params(self) -> QueryParams: if not hasattr(self, "_query_params"): - self._query_params = QueryParams(self._scope["query_string"]) + self._query_params = QueryParams(self.scope["query_string"]) return self._query_params @property def path_params(self) -> dict: - return self._scope.get("path_params", {}) + return self.scope.get("path_params", {}) @property def cookies(self) -> typing.Dict[str, str]: @@ -100,39 +96,41 @@ class HTTPConnection(Mapping): @property def client(self) -> Address: - host, port = self._scope.get("client") or (None, None) + host, port = self.scope.get("client") or (None, None) return Address(host=host, port=port) @property def session(self) -> dict: assert ( - "session" in self._scope + "session" in self.scope ), "SessionMiddleware must be installed to access request.session" - return self._scope["session"] + return self.scope["session"] @property def auth(self) -> typing.Any: assert ( - "auth" in self._scope + "auth" in self.scope ), "AuthenticationMiddleware must be installed to access request.auth" - return self._scope["auth"] + return self.scope["auth"] @property def user(self) -> typing.Any: assert ( - "user" in self._scope + "user" in self.scope ), "AuthenticationMiddleware must be installed to access request.user" - return self._scope["user"] + return self.scope["user"] @property def state(self) -> State: if not hasattr(self, "_state"): + # Ensure 'state' has an empty dict if it's not already populated. + self.scope.setdefault("state", {}) # Create a state instance with a reference to the dict in which it should store info - self._state = State(self._scope["state"]) + self._state = State(self.scope["state"]) return self._state def url_for(self, name: str, **path_params: typing.Any) -> str: - router = self._scope["router"] + router = self.scope["router"] url_path = router.url_path_for(name, **path_params) return url_path.make_absolute_url(base_url=self.url) @@ -151,7 +149,7 @@ class Request(HTTPConnection): @property def method(self) -> str: - return self._scope["method"] + return self.scope["method"] @property def receive(self) -> Receive: diff --git a/tests/middleware/test_base.py b/tests/middleware/test_base.py index a9d7d423..09a1d0b3 100644 --- a/tests/middleware/test_base.py +++ b/tests/middleware/test_base.py @@ -93,19 +93,23 @@ def test_state_data_across_multiple_middlewares(): class aMiddleware(BaseHTTPMiddleware): async def dispatch(self, request, call_next): request.state.foo = expected_value1 + print("a", dict(request)) response = await call_next(request) return response class bMiddleware(BaseHTTPMiddleware): async def dispatch(self, request, call_next): request.state.bar = expected_value2 + print("b", dict(request)) response = await call_next(request) response.headers["X-State-Foo"] = request.state.foo return response class cMiddleware(BaseHTTPMiddleware): async def dispatch(self, request, call_next): + print("c", dict(request)) response = await call_next(request) + print("c", dict(request)) response.headers["X-State-Bar"] = request.state.bar return response diff --git a/tests/test_requests.py b/tests/test_requests.py index 20b456f9..03c74c45 100644 --- a/tests/test_requests.py +++ b/tests/test_requests.py @@ -174,13 +174,8 @@ def test_request_scope_interface(): """ request = Request({"type": "http", "method": "GET", "path": "/abc/"}) assert request["method"] == "GET" - assert dict(request) == { - "type": "http", - "method": "GET", - "path": "/abc/", - "state": {}, - } - assert len(request) == 4 + assert dict(request) == {"type": "http", "method": "GET", "path": "/abc/"} + assert len(request) == 3 def test_request_without_setting_receive(): @@ -249,19 +244,14 @@ def test_request_state_object(): scope = {"state": {"old": "foo"}} s = State(scope["state"]) - assert s._state == scope["state"] - assert getattr(s, "_state") == scope["state"] s.new = "value" assert s.new == "value" - assert s._state["new"] == "value" # test if inner _state dict is updated. del s.new - try: - assert s.new == "value" # will bombed with AttributeError - except AttributeError as e: - assert str(e) == "'State' object has no attribute 'new'" + with pytest.raises(AttributeError): + s.new def test_request_state(): diff --git a/tests/test_responses.py b/tests/test_responses.py index c7ca0ac0..f90bde48 100644 --- a/tests/test_responses.py +++ b/tests/test_responses.py @@ -8,11 +8,11 @@ from starlette.background import BackgroundTask from starlette.requests import Request from starlette.responses import ( FileResponse, + JSONResponse, RedirectResponse, Response, StreamingResponse, UJSONResponse, - JSONResponse, ) from starlette.testclient import TestClient diff --git a/tests/test_websockets.py b/tests/test_websockets.py index 1c0519d6..41e77237 100644 --- a/tests/test_websockets.py +++ b/tests/test_websockets.py @@ -281,10 +281,5 @@ def test_websocket_scope_interface(): send=mock_send, ) assert websocket["type"] == "websocket" - assert dict(websocket) == { - "type": "websocket", - "path": "/abc/", - "headers": [], - "state": {}, - } - assert len(websocket) == 4 + assert dict(websocket) == {"type": "websocket", "path": "/abc/", "headers": []} + assert len(websocket) == 3