diff --git a/core/src/main/java/com/google/adk/utils/ModelNameUtils.java b/core/src/main/java/com/google/adk/utils/ModelNameUtils.java index 3a164e562..d39572a79 100644 --- a/core/src/main/java/com/google/adk/utils/ModelNameUtils.java +++ b/core/src/main/java/com/google/adk/utils/ModelNameUtils.java @@ -7,6 +7,8 @@ public class ModelNameUtils { private static final Pattern GEMINI_2_PATTERN = Pattern.compile("^gemini-2\\..*"); private static final Pattern PATH_PATTERN = Pattern.compile("^projects/[^/]+/locations/[^/]+/publishers/[^/]+/models/(.+)$"); + private static final Pattern APIGEE_PATTERN = + Pattern.compile("^apigee/(?:[^/]+/)?(?:[^/]+/)?(.+)$"); public static boolean isGemini2Model(String modelString) { if (modelString == null) { @@ -28,6 +30,10 @@ private static String extractModelName(String modelString) { if (matcher.matches()) { return matcher.group(1); } + Matcher apigeeMatcher = APIGEE_PATTERN.matcher(modelString); + if (apigeeMatcher.matches()) { + return apigeeMatcher.group(1); + } return modelString; } diff --git a/core/src/test/java/com/google/adk/utils/ModelNameUtilsTest.java b/core/src/test/java/com/google/adk/utils/ModelNameUtilsTest.java new file mode 100644 index 000000000..37853c477 --- /dev/null +++ b/core/src/test/java/com/google/adk/utils/ModelNameUtilsTest.java @@ -0,0 +1,72 @@ +package com.google.adk.utils; + +import static com.google.common.truth.Truth.assertThat; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class ModelNameUtilsTest { + + @Test + public void isGemini2Model_withGemini2Model_returnsTrue() { + assertThat(ModelNameUtils.isGemini2Model("gemini-2.5-flash")).isTrue(); + } + + @Test + public void isGemini2Model_withNonGemini2Model_returnsFalse() { + assertThat(ModelNameUtils.isGemini2Model("gemini-1.5-pro")).isFalse(); + } + + @Test + public void isGemini2Model_withPathBasedGemini2Model_returnsTrue() { + assertThat( + ModelNameUtils.isGemini2Model( + "projects/test-project/locations/us-central1/publishers/google/models/gemini-2.5-flash")) + .isTrue(); + } + + @Test + public void isGemini2Model_withPathBasedNonGemini2Model_returnsFalse() { + assertThat( + ModelNameUtils.isGemini2Model( + "projects/test-project/locations/us-central1/publishers/google/models/gemini-1.5-pro")) + .isFalse(); + } + + @Test + public void isGemini2Model_withApigeeGemini2Model_returnsTrue() { + assertThat(ModelNameUtils.isGemini2Model("apigee/gemini-2.5-flash")).isTrue(); + } + + @Test + public void isGemini2Model_withApigeeV1Gemini2Model_returnsTrue() { + assertThat(ModelNameUtils.isGemini2Model("apigee/v1/gemini-2.5-flash")).isTrue(); + } + + @Test + public void isGemini2Model_withApigeeProviderGemini2Model_returnsTrue() { + assertThat(ModelNameUtils.isGemini2Model("apigee/gemini/gemini-2.5-flash")).isTrue(); + } + + @Test + public void isGemini2Model_withApigeeProviderVertexGemini2Model_returnsTrue() { + assertThat(ModelNameUtils.isGemini2Model("apigee/vertex_ai/gemini-2.5-flash")).isTrue(); + } + + @Test + public void isGemini2Model_withApigeeProviderV1Gemini2Model_returnsTrue() { + assertThat(ModelNameUtils.isGemini2Model("apigee/gemini/v1/gemini-2.5-flash")).isTrue(); + } + + @Test + public void isGemini2Model_withApigeeProviderV1BetaGemini2Model_returnsTrue() { + assertThat(ModelNameUtils.isGemini2Model("apigee/vertex_ai/v1beta/gemini-2.5-flash")).isTrue(); + } + + @Test + public void isGemini2Model_withNullModel_returnsFalse() { + assertThat(ModelNameUtils.isGemini2Model(null)).isFalse(); + } +}