Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions src/mcp/server/mcpserver/utilities/func_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import json
from collections.abc import Awaitable, Callable, Sequence
from itertools import chain
from types import GenericAlias
from types import GenericAlias, NoneType
from typing import Annotated, Any, cast, get_args, get_origin, get_type_hints

import anyio
Expand Down Expand Up @@ -148,7 +148,7 @@ def pre_parse_json(self, data: dict[str, Any]) -> dict[str, Any]:
continue

field_info = key_to_field_info[data_key]
if isinstance(data_value, str) and field_info.annotation is not str:
if isinstance(data_value, str) and _should_pre_parse_json(field_info.annotation):
try:
pre_parsed = json.loads(data_value)
except json.JSONDecodeError:
Expand All @@ -167,6 +167,22 @@ def pre_parse_json(self, data: dict[str, Any]) -> dict[str, Any]:
)


def _is_simple_scalar_annotation(annotation: Any) -> bool:
return annotation in {str, int, float, bool, NoneType}


def _should_pre_parse_json(annotation: Any) -> bool:
"""Return whether string input for this annotation should be JSON-decoded."""
if annotation is str:
return False

origin = get_origin(annotation)
if is_union_origin(origin):
return not all(_is_simple_scalar_annotation(arg) for arg in get_args(annotation))

return True


def func_metadata(
func: Callable[..., Any],
skip_names: Sequence[str] = (),
Expand Down
35 changes: 35 additions & 0 deletions tests/server/mcpserver/test_func_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -551,6 +551,41 @@ def handle_json_payload(payload: str, strict_mode: bool = False) -> str:
assert result == f"Handled payload of length {len(json_array_payload)}"


@pytest.mark.anyio
async def test_optional_str_annotation_preserves_json_string():
def update_task(task_id: str | None = None) -> str:
assert isinstance(task_id, str)
return task_id

meta = func_metadata(update_task)

uuid = "3400e37e-b251-49d9-91b0-f8dd8602ff7e"
json_payload = '{"id": "3400e37e-b251-49d9-91b0-f8dd8602ff7e"}'

assert meta.pre_parse_json({"task_id": uuid})["task_id"] == uuid
assert meta.pre_parse_json({"task_id": json_payload})["task_id"] == json_payload
assert meta.pre_parse_json({"task_id": "[1, 2]"})["task_id"] == "[1, 2]"

result = await meta.call_fn_with_arg_validation(
update_task,
fn_is_async=False,
arguments_to_validate={"task_id": json_payload},
arguments_to_pass_directly=None,
)

assert result == json_payload


def test_str_or_list_still_pre_parses_lists():
def func_with_str_or_list(value: str | list[str]): # pragma: no cover
return value

meta = func_metadata(func_with_str_or_list)

assert meta.pre_parse_json({"value": "hello"})["value"] == "hello"
assert meta.pre_parse_json({"value": '["hello", "world"]'})["value"] == ["hello", "world"]


# Tests for structured output functionality


Expand Down
Loading