Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
243 changes: 137 additions & 106 deletions core/src/main/java/com/google/adk/runner/Runner.java
Original file line number Diff line number Diff line change
Expand Up @@ -132,15 +132,16 @@ public PluginManager pluginManager() {
}

/**
* Appends a new user message to the session history.
* Appends a new user message to the session history with optional state delta.
*
* @throws IllegalArgumentException if message has no parts.
*/
private Single<Event> appendNewMessageToSession(
Session session,
Content newMessage,
InvocationContext invocationContext,
boolean saveInputBlobsAsArtifacts) {
boolean saveInputBlobsAsArtifacts,
@Nullable Map<String, Object> stateDelta) {
if (newMessage.parts().isEmpty()) {
throw new IllegalArgumentException("No parts in the new_message.");
}
Expand Down Expand Up @@ -169,14 +170,20 @@ private Single<Event> appendNewMessageToSession(
}
}
// Appends only. We do not yield the event because it's not from the model.
Event event =
Event.Builder eventBuilder =
Event.builder()
.id(Event.generateEventId())
.invocationId(invocationContext.invocationId())
.author("user")
.content(Optional.of(newMessage))
.build();
return this.sessionService.appendEvent(session, event);
.content(Optional.of(newMessage));

// Add state delta if provided
if (stateDelta != null && !stateDelta.isEmpty()) {
eventBuilder.actions(
EventActions.builder().stateDelta(new ConcurrentHashMap<>(stateDelta)).build());
}

return this.sessionService.appendEvent(session, eventBuilder.build());
}

/** See {@link #runAsync(String, String, Content, RunConfig, Map)}. */
Expand Down Expand Up @@ -240,80 +247,104 @@ public Flowable<Event> runAsync(
Span span = Telemetry.getTracer().spanBuilder("invocation").startSpan();
try (Scope unusedScope = span.makeCurrent()) {
BaseAgent rootAgent = this.agent;
InvocationContext context =
newInvocationContext(
String invocationId = InvocationContext.newInvocationContextId();

// Create initial context
InvocationContext initialContext =
newInvocationContextWithId(
session,
Optional.of(newMessage),
/* liveRequestQueue= */ Optional.empty(),
runConfig);

// Emit state delta event if provided, using the same invocation ID
Single<Session> sessionSingle =
(stateDelta != null && !stateDelta.isEmpty())
? emitStateDeltaEvent(session, stateDelta, context.invocationId())
: Single.just(session);

Maybe<Event> beforeRunEvent =
this.pluginManager
.runBeforeRunCallback(context)
.map(
content ->
Event.builder()
.id(Event.generateEventId())
.invocationId(context.invocationId())
.author("model")
.content(Optional.of(content))
.build());

Flowable<Event> agentEvents =
sessionSingle.flatMapPublisher(
updatedSession ->
Flowable.defer(
() ->
this.pluginManager
.runOnUserMessageCallback(context, newMessage)
.switchIfEmpty(Single.just(newMessage))
.flatMap(
content ->
(content != null)
? appendNewMessageToSession(
runConfig,
invocationId);

Flowable<Event> events =
Flowable.defer(
() ->
this.pluginManager
.runOnUserMessageCallback(initialContext, newMessage)
.switchIfEmpty(Single.just(newMessage))
.flatMap(
content ->
(content != null)
? appendNewMessageToSession(
session,
content,
initialContext,
runConfig.saveInputBlobsAsArtifacts(),
stateDelta)
: Single.just(null))
.flatMapPublisher(
event -> {
// Get the updated session after the message and state delta are applied
return this.sessionService
.getSession(
session.appName(),
session.userId(),
session.id(),
Optional.empty())
.flatMapPublisher(
updatedSession -> {
// Create context with updated session for beforeRunCallback
InvocationContext contextWithUpdatedSession =
newInvocationContextWithId(
updatedSession,
content,
context,
runConfig.saveInputBlobsAsArtifacts())
: Single.just(null))
.flatMapPublisher(
event -> {
InvocationContext contextWithNewMessage =
newInvocationContext(
updatedSession,
event.content(),
Optional.empty(),
runConfig);
contextWithNewMessage.agent(
this.findAgentToRun(updatedSession, rootAgent));
return contextWithNewMessage
.agent()
.runAsync(contextWithNewMessage)
.flatMap(
agentEvent ->
this.sessionService
.appendEvent(updatedSession, agentEvent)
.flatMap(
registeredEvent ->
contextWithNewMessage
.pluginManager()
.runOnEventCallback(
contextWithNewMessage,
registeredEvent)
.defaultIfEmpty(registeredEvent))
.toFlowable());
})));

return beforeRunEvent
.toFlowable()
.switchIfEmpty(agentEvents)
.concatWith(Completable.defer(() -> pluginManager.runAfterRunCallback(context)))
event.content(),
Optional.empty(),
runConfig,
invocationId);
contextWithUpdatedSession.agent(
this.findAgentToRun(updatedSession, rootAgent));

// Call beforeRunCallback with updated session
Maybe<Event> beforeRunEvent =
this.pluginManager
.runBeforeRunCallback(contextWithUpdatedSession)
.map(
content ->
Event.builder()
.id(Event.generateEventId())
.invocationId(
contextWithUpdatedSession
.invocationId())
.author("model")
.content(Optional.of(content))
.build());

// Agent execution
Flowable<Event> agentEvents =
contextWithUpdatedSession
.agent()
.runAsync(contextWithUpdatedSession)
.flatMap(
agentEvent ->
this.sessionService
.appendEvent(updatedSession, agentEvent)
.flatMap(
registeredEvent ->
contextWithUpdatedSession
.pluginManager()
.runOnEventCallback(
contextWithUpdatedSession,
registeredEvent)
.defaultIfEmpty(
registeredEvent))
.toFlowable());

// If beforeRunCallback returns content, emit it and skip
// agent
return beforeRunEvent
.toFlowable()
.switchIfEmpty(agentEvents)
.concatWith(
Completable.defer(
() ->
pluginManager.runAfterRunCallback(
contextWithUpdatedSession)));
});
}));

return events
.doOnError(
throwable -> {
span.setStatus(StatusCode.ERROR, "Error in runAsync Flowable execution");
Expand All @@ -328,36 +359,6 @@ public Flowable<Event> runAsync(
}
}

/**
* Emits a state update event and returns the updated session.
*
* @param session The session to update.
* @param stateDelta The state delta to apply.
* @param invocationId The invocation ID to use for the state delta event.
* @return Single emitting the updated session after applying the state delta.
*/
private Single<Session> emitStateDeltaEvent(
Session session, Map<String, Object> stateDelta, String invocationId) {
ConcurrentHashMap<String, Object> deltaMap = new ConcurrentHashMap<>(stateDelta);

Event stateEvent =
Event.builder()
.id(Event.generateEventId())
.invocationId(invocationId)
.author("user")
.actions(EventActions.builder().stateDelta(deltaMap).build())
.timestamp(System.currentTimeMillis())
.build();

return this.sessionService
.appendEvent(session, stateEvent)
.flatMap(
event ->
this.sessionService
.getSession(session.appName(), session.userId(), session.id(), Optional.empty())
.switchIfEmpty(Single.error(new IllegalStateException("Session not found"))));
}

/**
* Creates an {@link InvocationContext} for a live (streaming) run.
*
Expand Down Expand Up @@ -414,6 +415,36 @@ private InvocationContext newInvocationContext(
return invocationContext;
}

/**
* Creates a new InvocationContext with a specific invocation ID.
*
* @return a new {@link InvocationContext} with the specified invocation ID.
*/
private InvocationContext newInvocationContextWithId(
Session session,
Optional<Content> newMessage,
Optional<LiveRequestQueue> liveRequestQueue,
RunConfig runConfig,
String invocationId) {
BaseAgent rootAgent = this.agent;
InvocationContext invocationContext =
new InvocationContext(
this.sessionService,
this.artifactService,
this.memoryService,
this.pluginManager,
liveRequestQueue,
/* branch= */ Optional.empty(),
invocationId,
rootAgent,
session,
newMessage,
runConfig,
/* endInvocation= */ false);
invocationContext.agent(this.findAgentToRun(session, rootAgent));
return invocationContext;
}

/**
* Runs the agent in live mode, appending generated events to the session.
*
Expand Down
Loading