Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions src/mcp/client/session_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from types import TracebackType
from typing import Any, TypeAlias

import anyio
import httpx
from pydantic import BaseModel, Field
from typing_extensions import Self
Expand Down Expand Up @@ -165,10 +164,9 @@ async def __aexit__(
if self._owns_exit_stack:
await self._exit_stack.aclose()

# Concurrently close session stacks.
async with anyio.create_task_group() as tg:
for exit_stack in self._session_exit_stacks.values():
tg.start_soon(exit_stack.aclose)
# Sequentially close session stacks to preserve AnyIO task contexts.
for exit_stack in list(self._session_exit_stacks.values()):
await exit_stack.aclose()

@property
def sessions(self) -> list[mcp.ClientSession]:
Expand Down
118 changes: 118 additions & 0 deletions tests/client/test_session_group.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,17 @@
import contextlib
import socket
import sys
from typing import cast
from unittest import mock

import httpx
import pytest

if sys.version_info >= (3, 11):
from builtins import BaseExceptionGroup, ExceptionGroup
else:
from exceptiongroup import BaseExceptionGroup, ExceptionGroup

import mcp
from mcp import types
from mcp.client.session_group import (
Expand Down Expand Up @@ -385,3 +393,113 @@ async def test_client_session_group_establish_session_parameterized(
# 3. Assert returned values
assert returned_server_info is mock_initialize_result.server_info
assert returned_session is mock_entered_session


def _free_tcp_port() -> int:
"""Return a TCP port number not currently bound on localhost."""
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
sock.bind(("127.0.0.1", 0))
return sock.getsockname()[1]


def _is_cancel_scope_runtime_error(exc: BaseException | None) -> bool:
"""Walk an exception chain looking for AnyIO cancel-scope RuntimeError."""
seen: set[int] = set()

def _walk(current: BaseException | None) -> bool:
if current is None or id(current) in seen:
return False
seen.add(id(current))

if isinstance(current, RuntimeError) and "cancel scope" in str(current).lower():
return True
if isinstance(current, BaseExceptionGroup):
group = cast("BaseExceptionGroup[BaseException]", current)
return any(_walk(child) for child in group.exceptions)
return _walk(current.__cause__) or _walk(current.__context__)

return _walk(exc)


@pytest.mark.anyio
async def test_unreachable_streamable_http_error_is_catchable() -> None:
"""Unreachable streamable-http servers raise catchable connection errors."""
port = _free_tcp_port()
server_params = StreamableHttpParameters(url=f"http://127.0.0.1:{port}/mcp/")

caught: BaseException | None = None

try:
async with ClientSessionGroup() as group:
try:
await group.connect_to_server(server_params)
except BaseException as inner: # noqa: BLE001
caught = inner
except BaseException as outer: # noqa: BLE001
caught = outer

assert caught is not None, (
"Expected to catch a connection error against an unreachable "
"streamable-http server, but no exception was raised."
)
assert not _is_cancel_scope_runtime_error(caught), (
"Regression of #915: connection error against an unreachable "
"streamable-http server was masked by an anyio cancel-scope "
f"RuntimeError. Got: {type(caught).__name__}: {caught}"
)


def test_is_cancel_scope_runtime_error_detected() -> None:
exc = RuntimeError("Attempted to exit cancel scope in a different task")

assert _is_cancel_scope_runtime_error(exc)


def test_is_cancel_scope_runtime_error_in_group_detected() -> None:
exc = ExceptionGroup(
"outer",
[ValueError("other"), RuntimeError("Attempted to exit cancel scope in a different task")],
)

assert _is_cancel_scope_runtime_error(exc)


def test_is_cancel_scope_non_runtime_error_not_detected() -> None:
assert not _is_cancel_scope_runtime_error(ValueError("cancel scope was mentioned"))


def test_is_cancel_scope_none_is_false() -> None:
assert not _is_cancel_scope_runtime_error(None)


def test_is_cancel_scope_in_cause_chain() -> None:
exc = ValueError("outer")
exc.__cause__ = RuntimeError("Attempted to exit cancel scope in a different task")

assert _is_cancel_scope_runtime_error(exc)


@pytest.mark.anyio
async def test_session_group_with_external_exit_stack(
mock_exit_stack: mock.MagicMock,
) -> None:
"""External exit stacks remain caller-managed."""
group = ClientSessionGroup(exit_stack=mock_exit_stack)

async with group:
pass

mock_exit_stack.__aenter__.assert_not_called()
mock_exit_stack.aclose.assert_not_called()


@pytest.mark.anyio
async def test_session_group_teardown_closes_session_stacks() -> None:
"""__aexit__ closes every session-level exit stack sequentially."""
session = mock.MagicMock(spec=mcp.ClientSession)
session_stack = mock.AsyncMock()

async with ClientSessionGroup() as group:
group._session_exit_stacks[session] = session_stack

session_stack.aclose.assert_awaited_once()
Loading