diff --git a/core/src/main/java/com/google/adk/events/Event.java b/core/src/main/java/com/google/adk/events/Event.java index b6141f569..d56a0ab68 100644 --- a/core/src/main/java/com/google/adk/events/Event.java +++ b/core/src/main/java/com/google/adk/events/Event.java @@ -30,6 +30,7 @@ import com.google.genai.types.FinishReason; import com.google.genai.types.FunctionCall; import com.google.genai.types.FunctionResponse; +import com.google.genai.types.GenerateContentResponseUsageMetadata; import com.google.genai.types.GroundingMetadata; import java.time.Instant; import java.util.List; @@ -54,6 +55,9 @@ public class Event extends JsonBaseModel { private Optional turnComplete = Optional.empty(); private Optional errorCode = Optional.empty(); private Optional errorMessage = Optional.empty(); + private Optional finishReason = Optional.empty(); + private Optional usageMetadata = Optional.empty(); + private Optional avgLogprobs = Optional.empty(); private Optional interrupted = Optional.empty(); private Optional branch = Optional.empty(); private Optional groundingMetadata = Optional.empty(); @@ -153,10 +157,19 @@ public Optional errorCode() { return errorCode; } + @JsonProperty("finishReason") + public Optional finishReason() { + return finishReason; + } + public void setErrorCode(Optional errorCode) { this.errorCode = errorCode; } + public void setFinishReason(Optional finishReason) { + this.finishReason = finishReason; + } + @JsonProperty("errorMessage") public Optional errorMessage() { return errorMessage; @@ -166,6 +179,24 @@ public void setErrorMessage(Optional errorMessage) { this.errorMessage = errorMessage; } + @JsonProperty("usageMetadata") + public Optional usageMetadata() { + return usageMetadata; + } + + public void setUsageMetadata(Optional usageMetadata) { + this.usageMetadata = usageMetadata; + } + + @JsonProperty("avgLogprobs") + public Optional avgLogprobs() { + return avgLogprobs; + } + + public void setAvgLogprobs(Optional avgLogprobs) { + this.avgLogprobs = avgLogprobs; + } + @JsonProperty("interrupted") public Optional interrupted() { return interrupted; @@ -299,6 +330,9 @@ public static class Builder { private Optional turnComplete = Optional.empty(); private Optional errorCode = Optional.empty(); private Optional errorMessage = Optional.empty(); + private Optional finishReason = Optional.empty(); + private Optional usageMetadata = Optional.empty(); + private Optional avgLogprobs = Optional.empty(); private Optional interrupted = Optional.empty(); private Optional branch = Optional.empty(); private Optional groundingMetadata = Optional.empty(); @@ -419,6 +453,45 @@ public Builder errorMessage(Optional value) { return this; } + @CanIgnoreReturnValue + @JsonProperty("finishReason") + public Builder finishReason(@Nullable FinishReason value) { + this.finishReason = Optional.ofNullable(value); + return this; + } + + @CanIgnoreReturnValue + public Builder finishReason(Optional value) { + this.finishReason = value; + return this; + } + + @CanIgnoreReturnValue + @JsonProperty("usageMetadata") + public Builder usageMetadata(@Nullable GenerateContentResponseUsageMetadata value) { + this.usageMetadata = Optional.ofNullable(value); + return this; + } + + @CanIgnoreReturnValue + public Builder usageMetadata(Optional value) { + this.usageMetadata = value; + return this; + } + + @CanIgnoreReturnValue + @JsonProperty("avgLogprobs") + public Builder avgLogprobs(@Nullable Double value) { + this.avgLogprobs = Optional.ofNullable(value); + return this; + } + + @CanIgnoreReturnValue + public Builder avgLogprobs(Optional value) { + this.avgLogprobs = value; + return this; + } + @CanIgnoreReturnValue @JsonProperty("interrupted") public Builder interrupted(@Nullable Boolean value) { @@ -496,10 +569,12 @@ public Event build() { event.setTurnComplete(turnComplete); event.setErrorCode(errorCode); event.setErrorMessage(errorMessage); + event.setFinishReason(finishReason); + event.setUsageMetadata(usageMetadata); + event.setAvgLogprobs(avgLogprobs); event.setInterrupted(interrupted); event.branch(branch); event.setGroundingMetadata(groundingMetadata); - event.setActions(actions().orElse(EventActions.builder().build())); event.setTimestamp(timestamp().orElse(Instant.now().toEpochMilli())); return event; @@ -529,6 +604,9 @@ public Builder toBuilder() { .turnComplete(this.turnComplete) .errorCode(this.errorCode) .errorMessage(this.errorMessage) + .finishReason(this.finishReason) + .usageMetadata(this.usageMetadata) + .avgLogprobs(this.avgLogprobs) .interrupted(this.interrupted) .branch(this.branch) .groundingMetadata(this.groundingMetadata); @@ -557,6 +635,9 @@ public boolean equals(Object obj) { && Objects.equals(turnComplete, other.turnComplete) && Objects.equals(errorCode, other.errorCode) && Objects.equals(errorMessage, other.errorMessage) + && Objects.equals(finishReason, other.finishReason) + && Objects.equals(usageMetadata, other.usageMetadata) + && Objects.equals(avgLogprobs, other.avgLogprobs) && Objects.equals(interrupted, other.interrupted) && Objects.equals(branch, other.branch) && Objects.equals(groundingMetadata, other.groundingMetadata); @@ -580,6 +661,9 @@ public int hashCode() { turnComplete, errorCode, errorMessage, + finishReason, + usageMetadata, + avgLogprobs, interrupted, branch, groundingMetadata, diff --git a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java index 34114d004..5c7dad328 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java @@ -633,7 +633,10 @@ private Event buildModelResponseEvent( .errorMessage(llmResponse.errorMessage()) .interrupted(llmResponse.interrupted()) .turnComplete(llmResponse.turnComplete()) - .groundingMetadata(llmResponse.groundingMetadata()); + .groundingMetadata(llmResponse.groundingMetadata()) + .avgLogprobs(llmResponse.avgLogprobs()) + .finishReason(llmResponse.finishReason()) + .usageMetadata(llmResponse.usageMetadata()); Event event = eventBuilder.build(); diff --git a/core/src/main/java/com/google/adk/models/LlmResponse.java b/core/src/main/java/com/google/adk/models/LlmResponse.java index 0d748d1f3..0a8c2806d 100644 --- a/core/src/main/java/com/google/adk/models/LlmResponse.java +++ b/core/src/main/java/com/google/adk/models/LlmResponse.java @@ -79,6 +79,14 @@ public abstract class LlmResponse extends JsonBaseModel { @JsonProperty("errorCode") public abstract Optional errorCode(); + /** Error code if the response is an error. Code varies by model. */ + @JsonProperty("finishReason") + public abstract Optional finishReason(); + + /** Error code if the response is an error. Code varies by model. */ + @JsonProperty("avgLogprobs") + public abstract Optional avgLogprobs(); + /** Error message if the response is an error. */ @JsonProperty("errorMessage") public abstract Optional errorMessage(); @@ -136,6 +144,16 @@ static LlmResponse.Builder jacksonBuilder() { public abstract Builder errorCode(Optional errorCode); + @JsonProperty("finishReason") + public abstract Builder finishReason(@Nullable FinishReason finishReason); + + public abstract Builder finishReason(Optional finishReason); + + @JsonProperty("avgLogprobs") + public abstract Builder avgLogprobs(@Nullable Double avgLogprobs); + + public abstract Builder avgLogprobs(Optional avgLogprobs); + @JsonProperty("errorMessage") public abstract Builder errorMessage(@Nullable String errorMessage); diff --git a/core/src/test/java/com/google/adk/events/EventTest.java b/core/src/test/java/com/google/adk/events/EventTest.java index 358ac8bdf..f443abee5 100644 --- a/core/src/test/java/com/google/adk/events/EventTest.java +++ b/core/src/test/java/com/google/adk/events/EventTest.java @@ -24,6 +24,7 @@ import com.google.genai.types.Content; import com.google.genai.types.FinishReason; import com.google.genai.types.FunctionCall; +import com.google.genai.types.GenerateContentResponseUsageMetadata; import com.google.genai.types.Part; import java.time.Instant; import java.util.concurrent.ConcurrentHashMap; @@ -67,6 +68,14 @@ public final class EventTest { .turnComplete(true) .errorCode(new FinishReason("error_code")) .errorMessage("error_message") + .finishReason(new FinishReason("finish_reason")) + .usageMetadata( + GenerateContentResponseUsageMetadata.builder() + .promptTokenCount(10) + .candidatesTokenCount(20) + .totalTokenCount(30) + .build()) + .avgLogprobs(0.5) .interrupted(true) .timestamp(123456789L) .build(); @@ -80,6 +89,15 @@ public void event_builder_works() { assertThat(EVENT.turnComplete().get()).isTrue(); assertThat(EVENT.errorCode()).hasValue(new FinishReason("error_code")); assertThat(EVENT.errorMessage()).hasValue("error_message"); + assertThat(EVENT.finishReason()).hasValue(new FinishReason("finish_reason")); + assertThat(EVENT.usageMetadata()) + .hasValue( + GenerateContentResponseUsageMetadata.builder() + .promptTokenCount(10) + .candidatesTokenCount(20) + .totalTokenCount(30) + .build()); + assertThat(EVENT.avgLogprobs()).hasValue(0.5); assertThat(EVENT.interrupted()).hasValue(true); assertThat(EVENT.timestamp()).isEqualTo(123456789L); assertThat(EVENT.actions()).isEqualTo(EVENT_ACTIONS); diff --git a/core/src/test/java/com/google/adk/flows/llmflows/BaseLlmFlowTest.java b/core/src/test/java/com/google/adk/flows/llmflows/BaseLlmFlowTest.java index 700a9c0c4..37a043930 100644 --- a/core/src/test/java/com/google/adk/flows/llmflows/BaseLlmFlowTest.java +++ b/core/src/test/java/com/google/adk/flows/llmflows/BaseLlmFlowTest.java @@ -37,7 +37,9 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.genai.types.Content; +import com.google.genai.types.FinishReason; import com.google.genai.types.FunctionDeclaration; +import com.google.genai.types.GenerateContentResponseUsageMetadata; import com.google.genai.types.Part; import io.reactivex.rxjava3.core.Flowable; import io.reactivex.rxjava3.core.Single; @@ -63,7 +65,44 @@ public void run_singleTextResponse_returnsSingleEvent() { List events = baseLlmFlow.run(invocationContext).toList().blockingGet(); assertThat(events).hasSize(1); - assertThat(getOnlyElement(events).content()).hasValue(content); + Event event = getOnlyElement(events); + assertThat(event.content()).hasValue(content); + assertThat(event.avgLogprobs()).isEmpty(); + assertThat(event.finishReason()).isEmpty(); + assertThat(event.usageMetadata()).isEmpty(); + } + + @Test + public void run_singleTextResponse_withMetadata_returnsSingleEventWithMetadata() { + Content content = Content.fromParts(Part.fromText("LLM response")); + LlmResponse llmResponse = + LlmResponse.builder() + .content(content) + .avgLogprobs(-0.123) + .finishReason(new FinishReason(FinishReason.Known.STOP)) + .usageMetadata( + GenerateContentResponseUsageMetadata.builder() + .promptTokenCount(10) + .candidatesTokenCount(20) + .build()) + .build(); + TestLlm testLlm = createTestLlm(llmResponse); + InvocationContext invocationContext = createInvocationContext(createTestAgent(testLlm)); + BaseLlmFlow baseLlmFlow = createBaseLlmFlowWithoutProcessors(); + + List events = baseLlmFlow.run(invocationContext).toList().blockingGet(); + + assertThat(events).hasSize(1); + Event event = getOnlyElement(events); + assertThat(event.content()).hasValue(content); + assertThat(event.avgLogprobs()).hasValue(-0.123); + assertThat(event.finishReason()).hasValue(new FinishReason(FinishReason.Known.STOP)); + assertThat(event.usageMetadata()) + .hasValue( + GenerateContentResponseUsageMetadata.builder() + .promptTokenCount(10) + .candidatesTokenCount(20) + .build()); } @Test