]> git.ipfire.org Git - thirdparty/starlette.git/commitdiff
Default WebSocket accept message headers to an empty list (#1422)
authorRoman Vlasenko <56698047+Klavionik@users.noreply.github.com>
Sat, 22 Jan 2022 14:09:18 +0000 (17:09 +0300)
committerGitHub <noreply@github.com>
Sat, 22 Jan 2022 14:09:18 +0000 (15:09 +0100)
* If no extra headers are passed, set it to an empty list

* Test websocket.accept() with no additional headers

* Update starlette/websockets.py

Co-authored-by: Marcelo Trylesinski <marcelotryle@gmail.com>
* Update tests/test_websockets.py

Co-authored-by: Amin Alaee <mohammadamin.alaee@gmail.com>
* Update tests/test_websockets.py

Co-authored-by: Marcelo Trylesinski <marcelotryle@gmail.com>
Co-authored-by: Marcelo Trylesinski <marcelotryle@gmail.com>
Co-authored-by: Amin Alaee <mohammadamin.alaee@gmail.com>
starlette/websockets.py
tests/test_websockets.py

index 7632b28cf12839ab8a10f6f8d0435ab8d5876493..bf4cca83fa3c4998b1580110ad05ac0f893805cd 100644 (file)
@@ -74,6 +74,8 @@ class WebSocket(HTTPConnection):
         subprotocol: str = None,
         headers: typing.Iterable[typing.Tuple[bytes, bytes]] = None,
     ) -> None:
+        headers = headers or []
+
         if self.client_state == WebSocketState.CONNECTING:
             # If we haven't yet seen the 'connect' message, then wait for it first.
             await self.receive()
index bf0253309831d47963d4b3bd749dca2afe29fd89..f3242d115eb8c45bc33ea24dd04b6c8ef974010e 100644 (file)
@@ -315,6 +315,20 @@ def test_additional_headers(test_client_factory):
         assert websocket.extra_headers == [(b"additional", b"header")]
 
 
+def test_no_additional_headers(test_client_factory):
+    def app(scope):
+        async def asgi(receive, send):
+            websocket = WebSocket(scope, receive=receive, send=send)
+            await websocket.accept()
+            await websocket.close()
+
+        return asgi
+
+    client = test_client_factory(app)
+    with client.websocket_connect("/") as websocket:
+        assert websocket.extra_headers == []
+
+
 def test_websocket_exception(test_client_factory):
     def app(scope):
         async def asgi(receive, send):