diff --git a/src/mcp/client/session_group.py b/src/mcp/client/session_group.py index 961021264..d900789bd 100644 --- a/src/mcp/client/session_group.py +++ b/src/mcp/client/session_group.py @@ -13,7 +13,6 @@ from types import TracebackType from typing import Any, TypeAlias -import anyio import httpx from pydantic import BaseModel, Field from typing_extensions import Self @@ -165,10 +164,9 @@ async def __aexit__( if self._owns_exit_stack: await self._exit_stack.aclose() - # Concurrently close session stacks. - async with anyio.create_task_group() as tg: - for exit_stack in self._session_exit_stacks.values(): - tg.start_soon(exit_stack.aclose) + # Sequentially close session stacks to preserve AnyIO task contexts. + for exit_stack in list(self._session_exit_stacks.values()): + await exit_stack.aclose() @property def sessions(self) -> list[mcp.ClientSession]: diff --git a/tests/client/test_session_group.py b/tests/client/test_session_group.py index 6a58b39f3..ad46f4af0 100644 --- a/tests/client/test_session_group.py +++ b/tests/client/test_session_group.py @@ -1,9 +1,17 @@ import contextlib +import socket +import sys +from typing import cast from unittest import mock import httpx import pytest +if sys.version_info >= (3, 11): + from builtins import BaseExceptionGroup, ExceptionGroup +else: + from exceptiongroup import BaseExceptionGroup, ExceptionGroup + import mcp from mcp import types from mcp.client.session_group import ( @@ -385,3 +393,113 @@ async def test_client_session_group_establish_session_parameterized( # 3. Assert returned values assert returned_server_info is mock_initialize_result.server_info assert returned_session is mock_entered_session + + +def _free_tcp_port() -> int: + """Return a TCP port number not currently bound on localhost.""" + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.bind(("127.0.0.1", 0)) + return sock.getsockname()[1] + + +def _is_cancel_scope_runtime_error(exc: BaseException | None) -> bool: + """Walk an exception chain looking for AnyIO cancel-scope RuntimeError.""" + seen: set[int] = set() + + def _walk(current: BaseException | None) -> bool: + if current is None or id(current) in seen: + return False + seen.add(id(current)) + + if isinstance(current, RuntimeError) and "cancel scope" in str(current).lower(): + return True + if isinstance(current, BaseExceptionGroup): + group = cast("BaseExceptionGroup[BaseException]", current) + return any(_walk(child) for child in group.exceptions) + return _walk(current.__cause__) or _walk(current.__context__) + + return _walk(exc) + + +@pytest.mark.anyio +async def test_unreachable_streamable_http_error_is_catchable() -> None: + """Unreachable streamable-http servers raise catchable connection errors.""" + port = _free_tcp_port() + server_params = StreamableHttpParameters(url=f"http://127.0.0.1:{port}/mcp/") + + caught: BaseException | None = None + + try: + async with ClientSessionGroup() as group: + try: + await group.connect_to_server(server_params) + except BaseException as inner: # noqa: BLE001 + caught = inner + except BaseException as outer: # noqa: BLE001 + caught = outer + + assert caught is not None, ( + "Expected to catch a connection error against an unreachable " + "streamable-http server, but no exception was raised." + ) + assert not _is_cancel_scope_runtime_error(caught), ( + "Regression of #915: connection error against an unreachable " + "streamable-http server was masked by an anyio cancel-scope " + f"RuntimeError. Got: {type(caught).__name__}: {caught}" + ) + + +def test_is_cancel_scope_runtime_error_detected() -> None: + exc = RuntimeError("Attempted to exit cancel scope in a different task") + + assert _is_cancel_scope_runtime_error(exc) + + +def test_is_cancel_scope_runtime_error_in_group_detected() -> None: + exc = ExceptionGroup( + "outer", + [ValueError("other"), RuntimeError("Attempted to exit cancel scope in a different task")], + ) + + assert _is_cancel_scope_runtime_error(exc) + + +def test_is_cancel_scope_non_runtime_error_not_detected() -> None: + assert not _is_cancel_scope_runtime_error(ValueError("cancel scope was mentioned")) + + +def test_is_cancel_scope_none_is_false() -> None: + assert not _is_cancel_scope_runtime_error(None) + + +def test_is_cancel_scope_in_cause_chain() -> None: + exc = ValueError("outer") + exc.__cause__ = RuntimeError("Attempted to exit cancel scope in a different task") + + assert _is_cancel_scope_runtime_error(exc) + + +@pytest.mark.anyio +async def test_session_group_with_external_exit_stack( + mock_exit_stack: mock.MagicMock, +) -> None: + """External exit stacks remain caller-managed.""" + group = ClientSessionGroup(exit_stack=mock_exit_stack) + + async with group: + pass + + mock_exit_stack.__aenter__.assert_not_called() + mock_exit_stack.aclose.assert_not_called() + + +@pytest.mark.anyio +async def test_session_group_teardown_closes_session_stacks() -> None: + """__aexit__ closes every session-level exit stack sequentially.""" + session = mock.MagicMock(spec=mcp.ClientSession) + session_stack = mock.AsyncMock() + + async with ClientSessionGroup() as group: + group._session_exit_stacks[session] = session_stack + + session_stack.aclose.assert_awaited_once()