diff --git a/core/src/main/java/com/google/adk/tools/FunctionTool.java b/core/src/main/java/com/google/adk/tools/FunctionTool.java index 54050dc9b..8cbc5d6f8 100644 --- a/core/src/main/java/com/google/adk/tools/FunctionTool.java +++ b/core/src/main/java/com/google/adk/tools/FunctionTool.java @@ -185,6 +185,17 @@ public Method func() { return func; } + /** Returns the underlying function's {@link Object} instance if present. */ + @Nullable + Object instance() { + return instance; + } + + /** Returns whether the function requires confirmation */ + boolean requireConfirmation() { + return requireConfirmation; + } + /** Returns true if the wrapped function returns a Flowable and can be used for streaming. */ public boolean isStreaming() { Type returnType = func.getGenericReturnType(); diff --git a/core/src/main/java/com/google/adk/tools/LongRunningFunctionTool.java b/core/src/main/java/com/google/adk/tools/LongRunningFunctionTool.java index 0cd02202d..68b7c242b 100644 --- a/core/src/main/java/com/google/adk/tools/LongRunningFunctionTool.java +++ b/core/src/main/java/com/google/adk/tools/LongRunningFunctionTool.java @@ -16,6 +16,8 @@ package com.google.adk.tools; +import com.fasterxml.jackson.core.type.TypeReference; +import com.google.adk.utils.ComponentRegistry; import java.lang.reflect.Method; import javax.annotation.Nullable; @@ -70,6 +72,11 @@ public static LongRunningFunctionTool create( return new LongRunningFunctionTool(instance, method, requireConfirmation); } + /** Creates a LongRunningFunctionTool from a FunctionTool. */ + public static LongRunningFunctionTool create(FunctionTool tool) { + return create(tool.instance(), tool.func(), tool.requireConfirmation()); + } + private LongRunningFunctionTool(Method func, boolean requireConfirmation) { super(null, func, /* isLongRunning= */ true, requireConfirmation); } @@ -78,4 +85,25 @@ private LongRunningFunctionTool( @Nullable Object instance, Method func, boolean requireConfirmation) { super(instance, func, /* isLongRunning= */ true, requireConfirmation); } + + public static LongRunningFunctionTool fromConfig(ToolArgsConfig config, String configAbsPath) { + String funcName = + config + .getOrEmpty("func", new TypeReference() {}) + .orElseThrow( + () -> + new IllegalArgumentException("\"func\" argument should be name of a function")); + + FunctionTool funcTool = + ComponentRegistry.getInstance() + .get(funcName, FunctionTool.class) + .orElseThrow( + () -> + new IllegalArgumentException( + String.format( + "failed to find FunctionTool \"%s\" in the ComponentRegistry", + funcName))); + + return create(funcTool); + } } diff --git a/core/src/main/java/com/google/adk/utils/ComponentRegistry.java b/core/src/main/java/com/google/adk/utils/ComponentRegistry.java index a526a0043..2f4e43561 100644 --- a/core/src/main/java/com/google/adk/utils/ComponentRegistry.java +++ b/core/src/main/java/com/google/adk/utils/ComponentRegistry.java @@ -32,6 +32,7 @@ import com.google.adk.tools.ExitLoopTool; import com.google.adk.tools.GoogleSearchTool; import com.google.adk.tools.LoadArtifactsTool; +import com.google.adk.tools.LongRunningFunctionTool; import com.google.adk.tools.UrlContextTool; import com.google.adk.tools.mcp.McpToolset; import java.util.Map; @@ -110,6 +111,7 @@ private void initializePreWiredEntries() { registerAdkToolInstance("url_context", UrlContextTool.INSTANCE); registerAdkToolClass(AgentTool.class); + registerAdkToolClass(LongRunningFunctionTool.class); registerAdkToolsetClass(McpToolset.class); // TODO: add all python tools that also exist in Java. diff --git a/core/src/test/java/com/google/adk/tools/LongRunningFunctionToolTest.java b/core/src/test/java/com/google/adk/tools/LongRunningFunctionToolTest.java index 1b19eb7c3..7793b026c 100644 --- a/core/src/test/java/com/google/adk/tools/LongRunningFunctionToolTest.java +++ b/core/src/test/java/com/google/adk/tools/LongRunningFunctionToolTest.java @@ -1,6 +1,7 @@ package com.google.adk.tools; import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; import com.google.adk.agents.LlmAgent; import com.google.adk.artifacts.InMemoryArtifactService; @@ -13,6 +14,8 @@ import com.google.adk.sessions.InMemorySessionService; import com.google.adk.sessions.Session; import com.google.adk.testing.TestLlm; +import com.google.adk.tools.BaseTool.ToolArgsConfig; +import com.google.adk.utils.ComponentRegistry; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.errorprone.annotations.Keep; @@ -92,6 +95,51 @@ public void asyncFunction_handlesPendingAndResults() throws Exception { assertThat(TestFunctions.functionCalledCount.get()).isEqualTo(1); } + @Test + public void fromConfig_validConfig_createsTool() throws Exception { + // Register a FunctionTool to be retrieved by fromConfig + FunctionTool testTool = + FunctionTool.create( + null, TestFunctions.class.getMethod("increaseByOne", int.class, ToolContext.class)); + ComponentRegistry.getInstance().register("testFunc", testTool); + + ToolArgsConfig config = new ToolArgsConfig(); + config.put("func", "testFunc"); + LongRunningFunctionTool longRunningTool = + LongRunningFunctionTool.fromConfig(config, "testPath"); + + assertThat(longRunningTool).isNotNull(); + assertThat(longRunningTool.name()).isEqualTo("increase_by_one"); + } + + @Test + public void fromConfig_missingFunc_throwsException() { + ToolArgsConfig config = new ToolArgsConfig(); + assertThrows( + IllegalArgumentException.class, + () -> LongRunningFunctionTool.fromConfig(config, "testPath")); + } + + @Test + public void fromConfig_funcNotString_throwsException() { + ToolArgsConfig config = new ToolArgsConfig(); + config.put("func", 123); + assertThrows( + IllegalArgumentException.class, + () -> LongRunningFunctionTool.fromConfig(config, "testPath")); + } + + @Test + public void fromConfig_funcNotFound_throwsException() { + ToolArgsConfig config = new ToolArgsConfig(); + config.put("func", "nonExistentFunc"); + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> LongRunningFunctionTool.fromConfig(config, "testPath")); + assertThat(exception).hasMessageThat().contains("\"nonExistentFunc\""); + } + private static class TestFunctions { static final AtomicInteger functionCalledCount = new AtomicInteger(0);