diff --git a/src/calibre/srv/tests/web_sockets.py b/src/calibre/srv/tests/web_sockets.py index 6c31ca566c..b470c1e30f 100644 --- a/src/calibre/srv/tests/web_sockets.py +++ b/src/calibre/srv/tests/web_sockets.py @@ -165,6 +165,12 @@ class WSClient: reason = reason.encode('utf-8') self.write_frame(1, CLOSE, struct.pack(b'!H', code) + reason) + def __enter__(self): + return self + + def __exit__(self, *a): + self.socket.close() + class WSTestServer(TestServer): @@ -184,42 +190,42 @@ class WSTestServer(TestServer): class WebSocketTest(BaseTest): def simple_test(self, server, msgs, expected=(), close_code=NORMAL_CLOSE, send_close=True, close_reason=b'NORMAL CLOSE', ignore_send_failures=False): - client = server.connect() - for msg in msgs: - try: - if isinstance(msg, dict): - client.write_frame(**msg) - else: - client.write_message(msg) - except Exception: - if not ignore_send_failures: - raise + with server.connect() as client: + for msg in msgs: + try: + if isinstance(msg, dict): + client.write_frame(**msg) + else: + client.write_message(msg) + except Exception: + if not ignore_send_failures: + raise - expected_messages, expected_controls = [], [] - for ex in expected: - if isinstance(ex, str): - ex = TEXT, ex - elif isinstance(ex, bytes): - ex = BINARY, ex - elif isinstance(ex, numbers.Integral): - ex = ex, b'' - if ex[0] in CONTROL_CODES: - expected_controls.append(ex) - else: - expected_messages.append(ex) - if send_close: - client.write_close(close_code, close_reason) - try: - messages, control_frames = client.read_messages() - except ConnectionAbortedError: - if expected_messages or expected_controls or send_close: - raise - return - self.ae(expected_messages, messages) - self.assertGreaterEqual(len(control_frames), 1) - self.ae(expected_controls, control_frames[:-1]) - self.ae(control_frames[-1][0], CLOSE) - self.ae(close_code, struct.unpack_from(b'!H', control_frames[-1][1], 0)[0]) + expected_messages, expected_controls = [], [] + for ex in expected: + if isinstance(ex, str): + ex = TEXT, ex + elif isinstance(ex, bytes): + ex = BINARY, ex + elif isinstance(ex, numbers.Integral): + ex = ex, b'' + if ex[0] in CONTROL_CODES: + expected_controls.append(ex) + else: + expected_messages.append(ex) + if send_close: + client.write_close(close_code, close_reason) + try: + messages, control_frames = client.read_messages() + except ConnectionAbortedError: + if expected_messages or expected_controls or send_close: + raise + return + self.ae(expected_messages, messages) + self.assertGreaterEqual(len(control_frames), 1) + self.ae(expected_controls, control_frames[:-1]) + self.ae(control_frames[-1][0], CLOSE) + self.ae(close_code, struct.unpack_from(b'!H', control_frames[-1][1], 0)[0]) def test_websocket_basic(self): 'Test basic interaction with the websocket server'