Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 19 additions & 9 deletions rigging/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -1845,13 +1845,12 @@ async def _step( # noqa: PLR0915, PLR0912
*[state.ready_event.wait() for state in states if not state.completed],
)

# TODO(nick): Are we good to throw exceptions here?
for task in tasks:
for state, task in zip(states, tasks, strict=True):
if task.done() and (exception := task.exception()):
raise exception
state.chat.error = exception
state.chat.failed = True

for state in states:
if state.ready_event.is_set() and state.step:
elif state.ready_event.is_set() and state.step:
step = state.step.with_parent(current_step)

if step.depth > max_depth:
Expand All @@ -1875,10 +1874,12 @@ async def _step( # noqa: PLR0915, PLR0912
for task in tasks:
if not task.done():
task.cancel()
await asyncio.gather(*tasks) # TODO(nick): return_exceptions=True ?
await asyncio.gather(*tasks, return_exceptions=True)

chats = ChatList([state.chat for state in states if state.chat])

self._raise_if_failed(chats, on_failed)

current_step = PipelineStep(
state="callback",
chats=chats,
Expand Down Expand Up @@ -1917,9 +1918,18 @@ async def _step( # noqa: PLR0915, PLR0912
)

async with contextlib.AsyncExitStack() as exit_stack:
result = map_task(chats)
if inspect.isawaitable(result):
result = await result
try:
result = map_task(chats)
if inspect.isawaitable(result):
result = await result
except Exception as e: # noqa: BLE001
# If the map raised an exception, assign it to all the chats
for chat in chats:
chat.error = e
chat.failed = True

self._raise_if_failed(chats, on_failed)
continue

if isinstance(result, contextlib.AbstractAsyncContextManager):
result = await exit_stack.enter_async_context(result)
Expand Down
56 changes: 56 additions & 0 deletions tests/test_chat_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,3 +233,59 @@ async def watch_function(chats: list[Chat]) -> None:
# Watch should be called at least once
assert len(watch_calls) >= 1
assert all(calls >= 1 for calls in watch_calls)


@pytest.mark.asyncio
async def test_map_callback_exception_handling() -> None:
"""Test that exceptions in map callback functions are properly caught and assigned to chat.error and chat.failed."""

generator = FixedGenerator(model="fixed", text="Response", params=GenerateParams())

async def failing_map_callback(chats: list[Chat]) -> list[Chat]:
# Simulate an exception in the map callback
raise RuntimeError("Map callback failure")

pipeline = generator.chat("test").map(failing_map_callback)

# Test with default on_failed behavior (should raise)
with pytest.raises(RuntimeError):
await pipeline.run()

# Should still raise as RuntimeError is not in the default catch list
with pytest.raises(RuntimeError):
await pipeline.run(on_failed="include")

# Should capture now
chat = await pipeline.catch(RuntimeError, on_failed="include").run()

assert chat.failed is True
assert isinstance(chat.error, RuntimeError)
assert str(chat.error) == "Map callback failure"


@pytest.mark.asyncio
async def test_then_callback_exception_handling() -> None:
"""Test that exceptions in then callback functions are properly caught and assigned to chat.error and chat.failed."""

generator = FixedGenerator(model="fixed", text="Response", params=GenerateParams())

async def failing_then_callback(chat: Chat) -> PipelineStepContextManager:
# Simulate an exception in the then callback
raise RuntimeError("Then callback failure")

pipeline = generator.chat("test").then(failing_then_callback)

# Test with default on_failed behavior (should raise)
with pytest.raises(RuntimeError):
await pipeline.run()

# Should still raise as RuntimeError is not in the default catch list
with pytest.raises(RuntimeError):
await pipeline.run(on_failed="include")

# Should capture now
chat = await pipeline.catch(RuntimeError, on_failed="include").run()

assert chat.failed is True
assert isinstance(chat.error, RuntimeError)
assert str(chat.error) == "Then callback failure"
Loading