From 3896ff7b452ff475511db550fce43a8be7fedfa7 Mon Sep 17 00:00:00 2001 From: kimizuka Date: Thu, 28 May 2026 10:51:11 +0900 Subject: [PATCH] runtime: populate ModelID in after_llm_call hook payload The ModelID field on hooks.Input is documented as populated for after_llm_call (pkg/hooks/types.go:177-186), but executeAfterLLMCallHooks never sets it, so hook handlers reading model_id always observed an empty string. Thread the model identifier into the dispatch so the payload matches the documented contract. The harness path already had modelID in scope; the loop path uses modelID.String() from the turn's resolved modelsdev.ID. Add a regression test that captures the after_llm_call Input via a builtin hook and asserts ModelID equals the provider's canonical id. Signed-off-by: kimizuka --- pkg/runtime/after_llm_call_test.go | 76 ++++++++++++++++++++++++++++++ pkg/runtime/harness.go | 2 +- pkg/runtime/hooks.go | 3 +- pkg/runtime/loop.go | 2 +- 4 files changed, 80 insertions(+), 3 deletions(-) create mode 100644 pkg/runtime/after_llm_call_test.go diff --git a/pkg/runtime/after_llm_call_test.go b/pkg/runtime/after_llm_call_test.go new file mode 100644 index 000000000..2f0f519d8 --- /dev/null +++ b/pkg/runtime/after_llm_call_test.go @@ -0,0 +1,76 @@ +package runtime + +import ( + "context" + "sync/atomic" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/docker/docker-agent/pkg/agent" + "github.com/docker/docker-agent/pkg/config/latest" + "github.com/docker/docker-agent/pkg/hooks" + "github.com/docker/docker-agent/pkg/session" + "github.com/docker/docker-agent/pkg/team" +) + +// TestAfterLLMCallHook_PopulatesModelID is a regression test for the +// doc/impl mismatch where [hooks.Input.ModelID] is documented as +// populated for after_llm_call but executeAfterLLMCallHooks never +// actually set it — handlers reading model_id always saw an empty +// string. 
A single successful turn must dispatch after_llm_call with +// ModelID equal to the provider's canonical "<provider>/<model>" id. +func TestAfterLLMCallHook_PopulatesModelID(t *testing.T) { + t.Parallel() + + const ( + hookName = "test-after-llm-model-id" + modelID = "test/mock-model" + ) + + var captured atomic.Pointer[hooks.Input] + + stream := newStreamBuilder(). + AddContent("ok"). + AddStopWithUsage(1, 1). + Build() + prov := &mockProvider{id: modelID, stream: stream} + + root := agent.New("root", "test agent", + agent.WithModel(prov), + agent.WithHooks(&latest.HooksConfig{ + AfterLLMCall: []latest.HookDefinition{ + {Type: "builtin", Command: hookName}, + }, + }), + ) + tm := team.New(team.WithAgents(root)) + + rt, err := NewLocalRuntime(tm, + WithSessionCompaction(false), + WithModelStore(mockModelStore{}), + ) + require.NoError(t, err) + + require.NoError(t, rt.hooksRegistry.RegisterBuiltin( + hookName, + func(_ context.Context, in *hooks.Input, _ []string) (*hooks.Output, error) { + snap := *in + captured.Store(&snap) + return nil, nil + }, + )) + + sess := session.New(session.WithUserMessage("hi")) + sess.Title = "Unit Test" + + for range rt.RunStream(t.Context(), sess) { + } + + got := captured.Load() + require.NotNil(t, got, "after_llm_call hook must fire on a successful turn") + assert.Equal(t, modelID, got.ModelID, + "after_llm_call payload must include the canonical model id; "+ + "see pkg/hooks/types.go:177-186 for the documented contract") +} diff --git a/pkg/runtime/harness.go b/pkg/runtime/harness.go index 431afc3fa..c48b380f8 100644 --- a/pkg/runtime/harness.go +++ b/pkg/runtime/harness.go @@ -189,7 +189,7 @@ func (r *LocalRuntime) runHarnessAgent(ctx context.Context, sess *session.Sessio content = strings.TrimSpace(finalResult) } - r.executeAfterLLMCallHooks(ctx, sess, a, content) + r.executeAfterLLMCallHooks(ctx, sess, a, modelID, content) r.recordHarnessAssistantMessage(sess, a, content, modelID, usage, cost, events) r.executeStopHooks(ctx, sess, a, content, 
events) diff --git a/pkg/runtime/hooks.go b/pkg/runtime/hooks.go index 6ca312f77..cca582d90 100644 --- a/pkg/runtime/hooks.go +++ b/pkg/runtime/hooks.go @@ -443,10 +443,11 @@ func (r *LocalRuntime) executeBeforeLLMCallHooks( // stop_response (matching the stop event), so handlers can reuse the // same parsing logic. Failed model calls fire on_error instead and // skip this event. -func (r *LocalRuntime) executeAfterLLMCallHooks(ctx context.Context, sess *session.Session, a *agent.Agent, responseContent string) { +func (r *LocalRuntime) executeAfterLLMCallHooks(ctx context.Context, sess *session.Session, a *agent.Agent, modelID, responseContent string) { r.dispatchHook(ctx, a, hooks.EventAfterLLMCall, &hooks.Input{ SessionID: sess.ID, AgentName: a.Name(), + ModelID: modelID, StopResponse: responseContent, LastUserMessage: sess.GetLastUserMessageContent(), }, nil) diff --git a/pkg/runtime/loop.go b/pkg/runtime/loop.go index aaa1ba6cd..562b61345 100644 --- a/pkg/runtime/loop.go +++ b/pkg/runtime/loop.go @@ -569,7 +569,7 @@ func (r *LocalRuntime) runTurn( // fire on_error above. The assistant text content is passed // via stop_response, matching the stop event's payload, so // handlers can reuse the same parsing. - r.executeAfterLLMCallHooks(ctx, sess, a, res.Content) + r.executeAfterLLMCallHooks(ctx, sess, a, modelID.String(), res.Content) if usedModel != nil && usedModel.ID() != model.ID() { slog.InfoContext(ctx, "Used fallback model", "agent", a.Name(), "primary", model.ID().String(), "used", usedModel.ID().String())