diff --git a/tornado/test/websocket_test.py b/tornado/test/websocket_test.py index 4d39f37046..47314d6d07 100644 --- a/tornado/test/websocket_test.py +++ b/tornado/test/websocket_test.py @@ -483,6 +483,11 @@ def test_write_after_close(self): with self.assertRaises(WebSocketClosedError): ws.write_message("hello") + @gen_test + def test_reference_self_after_been_closed(self): + ws = yield self.ws_connect("/close_reason") + self.assertIs(ws.connect_future, None) + @gen_test def test_async_prepare(self): # Previously, an async prepare method triggered a bug that would diff --git a/tornado/websocket.py b/tornado/websocket.py index fbfd700887..573258a434 100644 --- a/tornado/websocket.py +++ b/tornado/websocket.py @@ -1360,7 +1360,9 @@ def __init__( subprotocols: Optional[List[str]] = None, resolver: Optional[Resolver] = None, ) -> None: - self.connect_future = Future() # type: Future[WebSocketClientConnection] + self.connect_future = ( + Future() + ) # type: Union[Future[WebSocketClientConnection], None] self.read_queue = Queue(1) # type: Queue[Union[None, str, bytes]] self.key = base64.b64encode(os.urandom(16)) self._on_message_callback = on_message_callback @@ -1437,11 +1439,12 @@ def close(self, code: Optional[int] = None, reason: Optional[str] = None) -> Non self.protocol = None # type: ignore def on_connection_close(self) -> None: - if not self.connect_future.done(): + if self.connect_future and not self.connect_future.done(): self.connect_future.set_exception(StreamClosedError()) self._on_message(None) self.tcp_client.close() super().on_connection_close() + self.connect_future = None def on_ws_connection_close( self, close_code: Optional[int] = None, close_reason: Optional[str] = None @@ -1451,7 +1454,7 @@ def on_ws_connection_close( self.on_connection_close() def _on_http_response(self, response: httpclient.HTTPResponse) -> None: - if not self.connect_future.done(): + if self.connect_future and not self.connect_future.done(): if response.error: self.connect_future.set_exception(response.error) else: @@ -1487,7 +1490,8 @@ async def headers_received( # ability to see exceptions. self.final_callback = None # type: ignore - future_set_result_unless_cancelled(self.connect_future, self) + if self.connect_future: + future_set_result_unless_cancelled(self.connect_future, self) def write_message( self, message: Union[str, bytes, Dict[str, Any]], binary: bool = False @@ -1664,6 +1668,8 @@ def websocket_connect( subprotocols=subprotocols, resolver=resolver, ) - if callback is not None: - IOLoop.current().add_future(conn.connect_future, callback) - return conn.connect_future + if conn.connect_future: + if callback is not None: + IOLoop.current().add_future(conn.connect_future, callback) + return conn.connect_future + raise WebSocketError("Initialize websocket client")