diff --git a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java index 0ee1baef1..d165d0b77 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java @@ -52,7 +52,6 @@ import io.reactivex.rxjava3.observers.DisposableCompletableObserver; import io.reactivex.rxjava3.schedulers.Schedulers; import java.util.ArrayList; -import java.util.Collections; import java.util.List; import java.util.Optional; import java.util.Set; @@ -131,9 +130,10 @@ protected Single preprocess( /** * Post-processes the LLM response after receiving it from the LLM. Executes all registered {@link - * ResponseProcessor} instances. Handles function calls if present in the response. + * ResponseProcessor} instances. Emits events for the model response and any subsequent function + * calls. */ - protected Single postprocess( + protected Flowable postprocess( InvocationContext context, Event baseEventForLlmResponse, LlmRequest llmRequest, @@ -154,46 +154,36 @@ protected Single postprocess( .map(ResponseProcessingResult::updatedResponse); } - return currentLlmResponse.flatMap( + return currentLlmResponse.flatMapPublisher( updatedResponse -> { + Flowable processorEvents = Flowable.fromIterable(Iterables.concat(eventIterables)); + if (updatedResponse.content().isEmpty() && updatedResponse.errorCode().isEmpty() && !updatedResponse.interrupted().orElse(false) && !updatedResponse.turnComplete().orElse(false)) { - return Single.just( - ResponseProcessingResult.create( - updatedResponse, Iterables.concat(eventIterables), Optional.empty())); + return processorEvents; } Event modelResponseEvent = buildModelResponseEvent(baseEventForLlmResponse, llmRequest, updatedResponse); - eventIterables.add(Collections.singleton(modelResponseEvent)); - Maybe maybeFunctionCallEvent; + Flowable modelEventStream = Flowable.just(modelResponseEvent); + if (modelResponseEvent.functionCalls().isEmpty()) { - maybeFunctionCallEvent = Maybe.empty(); - } else if (context.runConfig().streamingMode() == StreamingMode.BIDI) { + return processorEvents.concatWith(modelEventStream); + } + + Maybe maybeFunctionCallEvent; + if (context.runConfig().streamingMode() == StreamingMode.BIDI) { maybeFunctionCallEvent = Functions.handleFunctionCallsLive(context, modelResponseEvent, llmRequest.tools()); } else { maybeFunctionCallEvent = Functions.handleFunctionCalls(context, modelResponseEvent, llmRequest.tools()); } - return maybeFunctionCallEvent - .map(Optional::of) - .defaultIfEmpty(Optional.empty()) - .map( - functionCallEventOpt -> { - Optional transferToAgent = Optional.empty(); - if (functionCallEventOpt.isPresent()) { - Event functionCallEvent = functionCallEventOpt.get(); - eventIterables.add(Collections.singleton(functionCallEvent)); - transferToAgent = functionCallEvent.actions().transferToAgent(); - } - Iterable combinedEvents = Iterables.concat(eventIterables); - return ResponseProcessingResult.create( - updatedResponse, combinedEvents, transferToAgent); - }); + + return processorEvents.concatWith(modelEventStream).concatWith(maybeFunctionCallEvent); }); } @@ -374,33 +364,27 @@ private Flowable runOneStep(InvocationContext context) { Flowable restOfFlow = callLlm(context, llmRequestAfterPreprocess, mutableEventTemplate) .concatMap( - llmResponse -> { - Single postResultSingle = - postprocess( - context, - mutableEventTemplate, - llmRequestAfterPreprocess, - llmResponse); - - return postResultSingle - .doOnSuccess( - ignored -> { - String oldId = mutableEventTemplate.id(); - mutableEventTemplate.setId(Event.generateEventId()); - logger.debug( - "Updated mutableEventTemplate ID from {} to {} for next" - + " LlmResponse", - oldId, - mutableEventTemplate.id()); - }) - .toFlowable(); - }) + llmResponse -> + postprocess( + context, + mutableEventTemplate, + llmRequestAfterPreprocess, + llmResponse) + .doFinally( + () -> { + String oldId = mutableEventTemplate.id(); + mutableEventTemplate.setId(Event.generateEventId()); + logger.debug( + "Updated mutableEventTemplate ID from {} to {} for" + + " next LlmResponse", + oldId, + mutableEventTemplate.id()); + })) .concatMap( - postResult -> { - Flowable postProcessedEvents = - Flowable.fromIterable(postResult.events()); - if (postResult.transferToAgent().isPresent()) { - String agentToTransfer = postResult.transferToAgent().get(); + event -> { + Flowable postProcessedEvents = Flowable.just(event); + if (event.actions().transferToAgent().isPresent()) { + String agentToTransfer = event.actions().transferToAgent().get(); logger.debug("Transferring to agent: {}", agentToTransfer); BaseAgent rootAgent = context.agent().rootAgent(); BaseAgent nextAgent = rootAgent.findAgent(agentToTransfer); @@ -569,7 +553,7 @@ public void onError(Throwable e) { Flowable receiveFlow = connection .receive() - .flatMapSingle( + .flatMap( llmResponse -> { Event baseEventForThisLlmResponse = liveEventBuilderTemplate.id(Event.generateEventId()).build(); @@ -580,15 +564,15 @@ public void onError(Throwable e) { llmResponse); }) .flatMap( - postResult -> { - Flowable events = Flowable.fromIterable(postResult.events()); - if (postResult.transferToAgent().isPresent()) { + event -> { + Flowable events = Flowable.just(event); + if (event.actions().transferToAgent().isPresent()) { BaseAgent rootAgent = invocationContext.agent().rootAgent(); BaseAgent nextAgent = - rootAgent.findAgent(postResult.transferToAgent().get()); + rootAgent.findAgent(event.actions().transferToAgent().get()); if (nextAgent == null) { throw new IllegalStateException( - "Agent not found: " + postResult.transferToAgent().get()); + "Agent not found: " + event.actions().transferToAgent().get()); } Flowable nextAgentEvents = nextAgent.runLive(invocationContext);