From c175fe20d258a5ff3641f3723a5a601c5aaf7840 Mon Sep 17 00:00:00 2001 From: Wei Sun Date: Mon, 13 Oct 2025 22:02:28 -0700 Subject: [PATCH] fix: `deltaState` should be appended with `newMessage` event and `beforeRunCallback` should be called after that PiperOrigin-RevId: 818993721 --- .../java/com/google/adk/runner/Runner.java | 243 ++++++++++-------- .../com/google/adk/runner/RunnerTest.java | 113 +++++++- 2 files changed, 240 insertions(+), 116 deletions(-) diff --git a/core/src/main/java/com/google/adk/runner/Runner.java b/core/src/main/java/com/google/adk/runner/Runner.java index 4acfeb121..b96a5605c 100644 --- a/core/src/main/java/com/google/adk/runner/Runner.java +++ b/core/src/main/java/com/google/adk/runner/Runner.java @@ -132,7 +132,7 @@ public PluginManager pluginManager() { } /** - * Appends a new user message to the session history. + * Appends a new user message to the session history with optional state delta. * * @throws IllegalArgumentException if message has no parts. */ @@ -140,7 +140,8 @@ private Single appendNewMessageToSession( Session session, Content newMessage, InvocationContext invocationContext, - boolean saveInputBlobsAsArtifacts) { + boolean saveInputBlobsAsArtifacts, + @Nullable Map stateDelta) { if (newMessage.parts().isEmpty()) { throw new IllegalArgumentException("No parts in the new_message."); } @@ -169,14 +170,20 @@ private Single appendNewMessageToSession( } } // Appends only. We do not yield the event because it's not from the model. - Event event = + Event.Builder eventBuilder = Event.builder() .id(Event.generateEventId()) .invocationId(invocationContext.invocationId()) .author("user") - .content(Optional.of(newMessage)) - .build(); - return this.sessionService.appendEvent(session, event); + .content(Optional.of(newMessage)); + + // Add state delta if provided + if (stateDelta != null && !stateDelta.isEmpty()) { + eventBuilder.actions( + EventActions.builder().stateDelta(new ConcurrentHashMap<>(stateDelta)).build()); + } + + return this.sessionService.appendEvent(session, eventBuilder.build()); } /** See {@link #runAsync(String, String, Content, RunConfig, Map)}. */ @@ -240,80 +247,104 @@ public Flowable runAsync( Span span = Telemetry.getTracer().spanBuilder("invocation").startSpan(); try (Scope unusedScope = span.makeCurrent()) { BaseAgent rootAgent = this.agent; - InvocationContext context = - newInvocationContext( + String invocationId = InvocationContext.newInvocationContextId(); + + // Create initial context + InvocationContext initialContext = + newInvocationContextWithId( session, Optional.of(newMessage), /* liveRequestQueue= */ Optional.empty(), - runConfig); - - // Emit state delta event if provided, using the same invocation ID - Single sessionSingle = - (stateDelta != null && !stateDelta.isEmpty()) - ? emitStateDeltaEvent(session, stateDelta, context.invocationId()) - : Single.just(session); - - Maybe beforeRunEvent = - this.pluginManager - .runBeforeRunCallback(context) - .map( - content -> - Event.builder() - .id(Event.generateEventId()) - .invocationId(context.invocationId()) - .author("model") - .content(Optional.of(content)) - .build()); - - Flowable agentEvents = - sessionSingle.flatMapPublisher( - updatedSession -> - Flowable.defer( - () -> - this.pluginManager - .runOnUserMessageCallback(context, newMessage) - .switchIfEmpty(Single.just(newMessage)) - .flatMap( - content -> - (content != null) - ? appendNewMessageToSession( + runConfig, + invocationId); + + Flowable events = + Flowable.defer( + () -> + this.pluginManager + .runOnUserMessageCallback(initialContext, newMessage) + .switchIfEmpty(Single.just(newMessage)) + .flatMap( + content -> + (content != null) + ? appendNewMessageToSession( + session, + content, + initialContext, + runConfig.saveInputBlobsAsArtifacts(), + stateDelta) + : Single.just(null)) + .flatMapPublisher( + event -> { + // Get the updated session after the message and state delta are applied + return this.sessionService + .getSession( + session.appName(), + session.userId(), + session.id(), + Optional.empty()) + .flatMapPublisher( + updatedSession -> { + // Create context with updated session for beforeRunCallback + InvocationContext contextWithUpdatedSession = + newInvocationContextWithId( updatedSession, - content, - context, - runConfig.saveInputBlobsAsArtifacts()) - : Single.just(null)) - .flatMapPublisher( - event -> { - InvocationContext contextWithNewMessage = - newInvocationContext( - updatedSession, - event.content(), - Optional.empty(), - runConfig); - contextWithNewMessage.agent( - this.findAgentToRun(updatedSession, rootAgent)); - return contextWithNewMessage - .agent() - .runAsync(contextWithNewMessage) - .flatMap( - agentEvent -> - this.sessionService - .appendEvent(updatedSession, agentEvent) - .flatMap( - registeredEvent -> - contextWithNewMessage - .pluginManager() - .runOnEventCallback( - contextWithNewMessage, - registeredEvent) - .defaultIfEmpty(registeredEvent)) - .toFlowable()); - }))); - - return beforeRunEvent - .toFlowable() - .switchIfEmpty(agentEvents) - .concatWith(Completable.defer(() -> pluginManager.runAfterRunCallback(context))) + event.content(), + Optional.empty(), + runConfig, + invocationId); + contextWithUpdatedSession.agent( + this.findAgentToRun(updatedSession, rootAgent)); + + // Call beforeRunCallback with updated session + Maybe beforeRunEvent = + this.pluginManager + .runBeforeRunCallback(contextWithUpdatedSession) + .map( + content -> + Event.builder() + .id(Event.generateEventId()) + .invocationId( + contextWithUpdatedSession + .invocationId()) + .author("model") + .content(Optional.of(content)) + .build()); + + // Agent execution + Flowable agentEvents = + contextWithUpdatedSession + .agent() + .runAsync(contextWithUpdatedSession) + .flatMap( + agentEvent -> + this.sessionService + .appendEvent(updatedSession, agentEvent) + .flatMap( + registeredEvent -> + contextWithUpdatedSession + .pluginManager() + .runOnEventCallback( + contextWithUpdatedSession, + registeredEvent) + .defaultIfEmpty( + registeredEvent)) + .toFlowable()); + + // If beforeRunCallback returns content, emit it and skip + // agent + return beforeRunEvent + .toFlowable() + .switchIfEmpty(agentEvents) + .concatWith( + Completable.defer( + () -> + pluginManager.runAfterRunCallback( + contextWithUpdatedSession))); + }); + })); + + return events .doOnError( throwable -> { span.setStatus(StatusCode.ERROR, "Error in runAsync Flowable execution"); @@ -328,36 +359,6 @@ public Flowable runAsync( } } - /** - * Emits a state update event and returns the updated session. - * - * @param session The session to update. - * @param stateDelta The state delta to apply. - * @param invocationId The invocation ID to use for the state delta event. - * @return Single emitting the updated session after applying the state delta. - */ - private Single emitStateDeltaEvent( - Session session, Map stateDelta, String invocationId) { - ConcurrentHashMap deltaMap = new ConcurrentHashMap<>(stateDelta); - - Event stateEvent = - Event.builder() - .id(Event.generateEventId()) - .invocationId(invocationId) - .author("user") - .actions(EventActions.builder().stateDelta(deltaMap).build()) - .timestamp(System.currentTimeMillis()) - .build(); - - return this.sessionService - .appendEvent(session, stateEvent) - .flatMap( - event -> - this.sessionService - .getSession(session.appName(), session.userId(), session.id(), Optional.empty()) - .switchIfEmpty(Single.error(new IllegalStateException("Session not found")))); - } - /** * Creates an {@link InvocationContext} for a live (streaming) run. * @@ -414,6 +415,36 @@ private InvocationContext newInvocationContext( return invocationContext; } + /** + * Creates a new InvocationContext with a specific invocation ID. + * + * @return a new {@link InvocationContext} with the specified invocation ID. + */ + private InvocationContext newInvocationContextWithId( + Session session, + Optional newMessage, + Optional liveRequestQueue, + RunConfig runConfig, + String invocationId) { + BaseAgent rootAgent = this.agent; + InvocationContext invocationContext = + new InvocationContext( + this.sessionService, + this.artifactService, + this.memoryService, + this.pluginManager, + liveRequestQueue, + /* branch= */ Optional.empty(), + invocationId, + rootAgent, + session, + newMessage, + runConfig, + /* endInvocation= */ false); + invocationContext.agent(this.findAgentToRun(session, rootAgent)); + return invocationContext; + } + /** * Runs the agent in live mode, appending generated events to the session. * diff --git a/core/src/test/java/com/google/adk/runner/RunnerTest.java b/core/src/test/java/com/google/adk/runner/RunnerTest.java index 2cb5d3d9e..48364da41 100644 --- a/core/src/test/java/com/google/adk/runner/RunnerTest.java +++ b/core/src/test/java/com/google/adk/runner/RunnerTest.java @@ -28,6 +28,7 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import com.google.adk.agents.InvocationContext; import com.google.adk.agents.LlmAgent; import com.google.adk.agents.RunConfig; import com.google.adk.events.Event; @@ -450,7 +451,7 @@ public void runAsync_withNullStateDelta_doesNotModifySession() { } @Test - public void runAsync_withStateDelta_appendsStateEventToHistory() { + public void runAsync_withStateDelta_attachesStateToUserMessageEvent() { var unused = runner .runAsync( @@ -468,15 +469,39 @@ public void runAsync_withStateDelta_appendsStateEventToHistory() { .getSession("test", "user", session.id(), Optional.empty()) .blockingGet(); - assertThat( - finalSession.events().stream() - .anyMatch( - e -> - e.author().equals("user") - && e.actions() != null - && e.actions().stateDelta() != null - && !e.actions().stateDelta().isEmpty())) - .isTrue(); + // Verify state delta is attached to the user message event, not a separate event + Event userEvent = + finalSession.events().stream() + .filter( + e -> + e.author().equals("user") + && e.content().isPresent() + && e.content().get().parts().get().get(0).text().isPresent() + && e.content() + .get() + .parts() + .get() + .get(0) + .text() + .get() + .equals("test message")) + .findFirst() + .orElseThrow(); + + assertThat(userEvent.actions()).isNotNull(); + assertThat(userEvent.actions().stateDelta()).containsEntry("testKey", "testValue"); + + // Verify there is no separate state-only event + long stateOnlyEvents = + finalSession.events().stream() + .filter( + e -> + e.author().equals("user") + && e.content().isEmpty() + && e.actions() != null + && !e.actions().stateDelta().isEmpty()) + .count(); + assertThat(stateOnlyEvents).isEqualTo(0); } @Test @@ -510,6 +535,74 @@ public void runAsync_withStateDelta_mergesWithExistingState() { assertThat(finalSession.state()).containsEntry("new_key", "new_value"); } + @Test + public void beforeRunCallback_seesUserMessageInSession() { + ArgumentCaptor contextCaptor = + ArgumentCaptor.forClass(InvocationContext.class); + when(plugin.beforeRunCallback(contextCaptor.capture())).thenReturn(Maybe.empty()); + + var unused = + runner + .runAsync("user", session.id(), createContent("user message for callback")) + .toList() + .blockingGet(); + + // Verify beforeRunCallback was called + verify(plugin).beforeRunCallback(any()); + + // Verify the context passed to beforeRunCallback contains the session with user message + InvocationContext capturedContext = contextCaptor.getValue(); + Session sessionInCallback = capturedContext.session(); + + // Check that the user message is in the session history + boolean userMessageFound = + sessionInCallback.events().stream() + .anyMatch( + e -> + e.author().equals("user") + && e.content().isPresent() + && e.content().get().parts().get().get(0).text().isPresent() + && e.content() + .get() + .parts() + .get() + .get(0) + .text() + .get() + .contains("user message for callback")); + + assertThat(userMessageFound).isTrue(); + } + + @Test + public void beforeRunCallback_withStateDelta_seesMergedState() { + ArgumentCaptor contextCaptor = + ArgumentCaptor.forClass(InvocationContext.class); + when(plugin.beforeRunCallback(contextCaptor.capture())).thenReturn(Maybe.empty()); + + ImmutableMap stateDelta = + ImmutableMap.of("callback_key", "callback_value", "number", 123); + + var unused = + runner + .runAsync( + "user", + session.id(), + createContent("test with state"), + RunConfig.builder().build(), + stateDelta) + .toList() + .blockingGet(); + + // Verify the context passed to beforeRunCallback has the merged state + InvocationContext capturedContext = contextCaptor.getValue(); + Session sessionInCallback = capturedContext.session(); + + // Verify state delta was merged before beforeRunCallback was invoked + assertThat(sessionInCallback.state()).containsEntry("callback_key", "callback_value"); + assertThat(sessionInCallback.state()).containsEntry("number", 123); + } + private Content createContent(String text) { return Content.builder().parts(Part.builder().text(text).build()).build(); }