diff --git a/ipykernel/kernelbase.py b/ipykernel/kernelbase.py index 6465751c..263ff209 100644 --- a/ipykernel/kernelbase.py +++ b/ipykernel/kernelbase.py @@ -428,6 +428,8 @@ async def shell_main(self, subshell_id: str | None): await to_thread.run_sync(self.shell_stop.wait) tg.cancel_scope.cancel() + await socket.stop() + async def process_shell(self, socket=None): # socket=None is valid if kernel subshells are not supported. try: diff --git a/ipykernel/subshell.py b/ipykernel/subshell.py index 180e9ecb..34aa6efe 100644 --- a/ipykernel/subshell.py +++ b/ipykernel/subshell.py @@ -32,11 +32,4 @@ async def create_pair_socket( self._pair_socket = zmq_anyio.Socket(context, zmq.PAIR) self._pair_socket.connect(address) self.start_soon(self._pair_socket.start) - - def run(self) -> None: - try: - super().run() - finally: - if self._pair_socket is not None: - self._pair_socket.close() - self._pair_socket = None + self.add_teardown_callback(self._pair_socket.stop) diff --git a/ipykernel/subshell_manager.py b/ipykernel/subshell_manager.py index f4f92d2d..4068f27e 100644 --- a/ipykernel/subshell_manager.py +++ b/ipykernel/subshell_manager.py @@ -116,11 +116,13 @@ def close(self) -> None: async def get_control_other_socket(self, thread: BaseThread) -> zmq_anyio.Socket: if not self._control_other_socket.started.is_set(): await thread.task_group.start(self._control_other_socket.start) + thread.add_teardown_callback(self._control_other_socket.stop) return self._control_other_socket async def get_control_shell_channel_socket(self, thread: BaseThread) -> zmq_anyio.Socket: if not self._control_shell_channel_socket.started.is_set(): await thread.task_group.start(self._control_shell_channel_socket.start) + thread.add_teardown_callback(self._control_shell_channel_socket.stop) return self._control_shell_channel_socket def get_other_socket(self, subshell_id: str | None) -> zmq_anyio.Socket: @@ -281,6 +283,8 @@ async def _listen_for_subshell_reply( # Subshell no longer exists so exit gracefully return raise + finally: + await shell_channel_socket.stop() async def _process_control_request( self, request: dict[str, t.Any], subshell_task: t.Any diff --git a/ipykernel/thread.py b/ipykernel/thread.py index a66cb2a4..185fd8d8 100644 --- a/ipykernel/thread.py +++ b/ipykernel/thread.py @@ -3,6 +3,7 @@ from __future__ import annotations from collections.abc import Awaitable +from inspect import isawaitable from queue import Queue from threading import Event, Thread from typing import Any, Callable @@ -26,6 +27,7 @@ def __init__(self, **kwargs): self.is_pydev_daemon_thread = True self._tasks: Queue[tuple[str, Callable[[], Awaitable[Any]]] | None] = Queue() self._result: Queue[Any] = Queue() + self._teardown_callbacks: list[Callable[[], Any] | Callable[[], Awaitable[Any]]] = [] self._exception: Exception | None = None @property @@ -47,6 +49,9 @@ def run_sync(self, func: Callable[..., Any]) -> Any: self._tasks.put(("run_sync", func)) return self._result.get() + def add_teardown_callback(self, func: Callable[[], Any] | Callable[[], Awaitable[Any]]) -> None: + self._teardown_callbacks.append(func) + def run(self) -> None: """Run the thread.""" try: @@ -55,24 +60,37 @@ def run(self) -> None: self._exception = exc async def _main(self) -> None: - async with create_task_group() as tg: - self._task_group = tg - self.started.set() - while True: - task = await to_thread.run_sync(self._tasks.get) - if task is None: - break - func, arg = task - if func == "start_soon": - tg.start_soon(arg) - elif func == "run_async": - res = await arg - self._result.put(res) - else: # func == "run_sync" - res = arg() - self._result.put(res) - - tg.cancel_scope.cancel() + try: + async with create_task_group() as tg: + self._task_group = tg + self.started.set() + while True: + task = await to_thread.run_sync(self._tasks.get) + if task is None: + break + func, arg = task + if func == "start_soon": + tg.start_soon(arg) + elif func == "run_async": + res = await arg + self._result.put(res) + else: # func == "run_sync" + res = arg() + self._result.put(res) + + tg.cancel_scope.cancel() + finally: + exception = None + for teardown_callback in self._teardown_callbacks[::-1]: + try: + res = teardown_callback() + if isawaitable(res): + await res + except Exception as exc: + if exception is None: + exception = exc + if exception is not None: + raise exception def stop(self) -> None: """Stop the thread. diff --git a/pyproject.toml b/pyproject.toml index 1853544f..4bd8721c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,7 +33,7 @@ dependencies = [ "psutil>=5.7", "packaging>=22", "anyio>=4.8.0,<5.0.0", - "zmq-anyio >=0.3.6", + "zmq-anyio >=0.3.9", ] [project.urls]