import asyncio
+import contextlib
import functools
import socket
import traceback
class WebSocketBaseTestCase(AsyncHTTPTestCase):
+ def setUp(self):
+ super().setUp()
+ self.conns_to_close = []
+
+ def tearDown(self):
+ for conn in self.conns_to_close:
+ conn.close()
+ super().tearDown()
+
@gen.coroutine
def ws_connect(self, path, **kwargs):
ws = yield websocket_connect(
"ws://127.0.0.1:%d%s" % (self.get_http_port(), path), **kwargs
)
+ self.conns_to_close.append(ws)
raise gen.Return(ws)
@gen_test
def test_websocket_close_buffered_data(self):
- ws = yield websocket_connect("ws://127.0.0.1:%d/echo" % self.get_http_port())
- ws.write_message("hello")
- ws.write_message("world")
- # Close the underlying stream.
- ws.stream.close()
+ with contextlib.closing(
+ (yield websocket_connect("ws://127.0.0.1:%d/echo" % self.get_http_port()))
+ ) as ws:
+ ws.write_message("hello")
+ ws.write_message("world")
+ # Close the underlying stream.
+ ws.stream.close()
@gen_test
def test_websocket_headers(self):
# Ensure that arbitrary headers can be passed through websocket_connect.
- ws = yield websocket_connect(
- HTTPRequest(
- "ws://127.0.0.1:%d/header" % self.get_http_port(),
- headers={"X-Test": "hello"},
+ with contextlib.closing(
+ (
+ yield websocket_connect(
+ HTTPRequest(
+ "ws://127.0.0.1:%d/header" % self.get_http_port(),
+ headers={"X-Test": "hello"},
+ )
+ )
)
- )
- response = yield ws.read_message()
- self.assertEqual(response, "hello")
+ ) as ws:
+ response = yield ws.read_message()
+ self.assertEqual(response, "hello")
@gen_test
def test_websocket_header_echo(self):
# Ensure that headers can be returned in the response.
# Specifically, that arbitrary headers passed through websocket_connect
# can be returned.
- ws = yield websocket_connect(
- HTTPRequest(
- "ws://127.0.0.1:%d/header_echo" % self.get_http_port(),
- headers={"X-Test-Hello": "hello"},
+ with contextlib.closing(
+ (
+ yield websocket_connect(
+ HTTPRequest(
+ "ws://127.0.0.1:%d/header_echo" % self.get_http_port(),
+ headers={"X-Test-Hello": "hello"},
+ )
+ )
+ )
+ ) as ws:
+ self.assertEqual(ws.headers.get("X-Test-Hello"), "hello")
+ self.assertEqual(
+ ws.headers.get("X-Extra-Response-Header"), "Extra-Response-Value"
)
- )
- self.assertEqual(ws.headers.get("X-Test-Hello"), "hello")
- self.assertEqual(
- ws.headers.get("X-Extra-Response-Header"), "Extra-Response-Value"
- )
@gen_test
def test_server_close_reason(self):
url = "ws://127.0.0.1:%d/echo" % port
headers = {"Origin": "http://127.0.0.1:%d" % port}
- ws = yield websocket_connect(HTTPRequest(url, headers=headers))
- ws.write_message("hello")
- response = yield ws.read_message()
- self.assertEqual(response, "hello")
+ with contextlib.closing(
+ (yield websocket_connect(HTTPRequest(url, headers=headers)))
+ ) as ws:
+ ws.write_message("hello")
+ response = yield ws.read_message()
+ self.assertEqual(response, "hello")
@gen_test
def test_check_origin_valid_with_path(self):
url = "ws://127.0.0.1:%d/echo" % port
headers = {"Origin": "http://127.0.0.1:%d/something" % port}
- ws = yield websocket_connect(HTTPRequest(url, headers=headers))
- ws.write_message("hello")
- response = yield ws.read_message()
- self.assertEqual(response, "hello")
+ with contextlib.closing(
+ (yield websocket_connect(HTTPRequest(url, headers=headers)))
+ ) as ws:
+ ws.write_message("hello")
+ response = yield ws.read_message()
+ self.assertEqual(response, "hello")
@gen_test
def test_check_origin_invalid_partial_url(self):