Skip to content

Commit aae9d2c

Browse files
authored
feat: Adds streaming support for wrapOpenAI (#118)
* Adds support for tracing streams with traceable
* Polish
* Make stream tracing opt-in
* Rework to use a passthrough instead of a proxy
* Adds streaming support for wrapOpenAI
* Record stream cancellations as errors
* Adds streaming example
* Allow empty config default, add example
* Fix
1 parent 0217521 commit aae9d2c

5 files changed

Lines changed: 693 additions & 126 deletions

File tree

langsmith-java-core/src/main/kotlin/com/langchain/smith/wrappers/openai/WrapOpenAI.kt

Lines changed: 149 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,15 @@ import com.langchain.smith.tracing.traceable
1010
import com.openai.client.OpenAIClient
1111
import com.openai.core.ClientOptions
1212
import com.openai.core.RequestOptions
13+
import com.openai.core.http.StreamResponse
14+
import com.openai.helpers.ChatCompletionAccumulator
15+
import com.openai.helpers.ResponseAccumulator
1316
import com.openai.models.chat.completions.ChatCompletion
17+
import com.openai.models.chat.completions.ChatCompletionChunk
1418
import com.openai.models.chat.completions.ChatCompletionCreateParams
1519
import com.openai.models.responses.Response
1620
import com.openai.models.responses.ResponseCreateParams
21+
import com.openai.models.responses.ResponseStreamEvent
1722
import com.openai.services.blocking.ChatService
1823
import com.openai.services.blocking.ResponseService
1924
import com.openai.services.blocking.chat.ChatCompletionService
@@ -23,11 +28,10 @@ import java.util.function.Function
2328
/**
2429
* Wraps an [OpenAIClient] with LangSmith tracing.
2530
*
26-
* The returned client traces `chat().completions().create()` and `responses().create()` calls using
27-
* [traceable], recording inputs and outputs as JSON maps. All other methods — including
28-
* `createStreaming()` — pass through to the delegate unchanged.
29-
*
30-
* // TODO: Trace `createStreaming()` for both chat completions and responses.
31+
* The returned client traces `chat().completions().create()`, `responses().create()`, and their
32+
* streaming counterparts (`createStreaming()`) using [traceable], recording inputs and outputs as
33+
* JSON maps. Streaming calls aggregate chunks using the SDK's built-in accumulators and record the
34+
* final aggregated result when the stream closes.
3135
*
3236
* The [config] is used as the base configuration for all traced calls — the wrapper overrides
3337
* [TraceConfig.name] and [TraceConfig.runType] per operation.
@@ -55,7 +59,8 @@ import java.util.function.Function
5559
* @param config base tracing configuration (client, project, tags, etc.)
5660
* @return a wrapped [OpenAIClient] that traces LLM calls to LangSmith
5761
*/
58-
fun wrapOpenAI(client: OpenAIClient, config: TraceConfig): OpenAIClient =
62+
@JvmOverloads
63+
fun wrapOpenAI(client: OpenAIClient, config: TraceConfig = TraceConfig()): OpenAIClient =
5964
TracedOpenAIClient(client, config)
6065

6166
private class TracedOpenAIClient(
@@ -90,41 +95,42 @@ private class TracedChatService(
9095
}
9196

9297
/**
93-
* Builds the traced `create` functions (1-arg and 2-arg) for an OpenAI service.
98+
* Builds the traced `create` and `createStreaming` functions for an OpenAI service.
9499
*
95100
* Both [TracedChatCompletionService] and [TracedResponseService] follow the same pattern — this
96101
* extracts the shared tracing plumbing so each service only provides the type-specific bits.
97102
*
98103
* @param P the params type (e.g. [ChatCompletionCreateParams])
99104
* @param R the response type (e.g. [ChatCompletion])
100-
* @param config base tracing configuration
101-
* @param paramsToMap serializes params to a map for recording as run inputs
102-
* @param useResponsesApi whether this is the responses API (affects metadata)
103-
* @param createOne calls the underlying service's `create(params)` method
104-
* @param createTwo calls the underlying service's `create(params, requestOptions)` method
105+
* @param C the chunk/event type for streaming (e.g. [ChatCompletionChunk])
105106
*/
106-
private class TracedOpenAIClientCreate<P, R : Any>(
107+
private class TracedOpenAIClientCreate<P, R : Any, C>(
107108
config: TraceConfig,
108-
paramsToMap: (P) -> Map<String, Any?>,
109-
useResponsesApi: Boolean,
109+
private val paramsToMap: (P) -> Map<String, Any?>,
110+
private val useResponsesApi: Boolean,
110111
createOne: (P) -> R,
111112
createTwo: (P, RequestOptions) -> R,
113+
private val streamOne: (P) -> StreamResponse<C>,
114+
private val streamTwo: (P, RequestOptions) -> StreamResponse<C>,
115+
private val aggregator: java.util.function.Function<List<Any?>, Any?>,
112116
) {
117+
private val tracedConfig = config.toBuilder().name("ChatOpenAI").runType(RunType.LLM).build()
118+
119+
private val outputProcessor =
120+
Function<R, Map<String, Any?>> { response -> processChatCompletionOutput(toMap(response)) }
121+
113122
val oneArg: Function<P, R> =
114123
traceable(
115-
Function<P, R> {
116-
setInvocationParams(paramsToMap(it), useResponsesApi = useResponsesApi)
117-
createOne(it)
124+
Function<P, R> { p ->
125+
setInvocationParams(paramsToMap(p), useResponsesApi = useResponsesApi)
126+
createOne(p)
118127
},
119-
config
128+
tracedConfig
120129
.toBuilder()
121-
.name("ChatOpenAI")
122-
.runType(RunType.LLM)
123130
.processTracedIO(
124131
TraceProcessIO<P, R>(
125132
processInputs = Function { params -> paramsToMap(params) },
126-
processOutputs =
127-
Function { response -> processChatCompletionOutput(toMap(response)) },
133+
processOutputs = outputProcessor,
128134
)
129135
)
130136
.build(),
@@ -136,22 +142,77 @@ private class TracedOpenAIClientCreate<P, R : Any>(
136142
setInvocationParams(paramsToMap(params), useResponsesApi = useResponsesApi)
137143
createTwo(params, opts)
138144
},
139-
config
145+
tracedConfig
140146
.toBuilder()
141-
.name("ChatOpenAI")
142-
.runType(RunType.LLM)
143147
.processTracedIO(
144148
TraceProcessIO<Pair<P, RequestOptions>, R>(
145149
processInputs =
146150
Function { (params, opts) ->
147151
paramsToMap(params) + ("request_options" to toGenericMap(opts))
148152
},
149-
processOutputs =
150-
Function { response -> processChatCompletionOutput(toMap(response)) },
153+
processOutputs = outputProcessor,
151154
)
152155
)
153156
.build(),
154157
)
158+
159+
private val streamOutputProcessor =
160+
Function<Any, Map<String, Any?>> { processChatCompletionOutput(toMap(it!!)) }
161+
162+
/**
 * Traces the one-argument streaming call.
 *
 * The request is issued through [streamOne] first. The traced function then returns the response's
 * inner [java.util.stream.Stream] (which `traceable` instruments via `peek` + `onClose`), and the
 * instrumented stream is rewrapped in a [TracedStreamResponse] whose `close()` closes both the
 * instrumented stream and the original response.
 *
 * @param params the request parameters; also recorded as the run inputs via [paramsToMap]
 * @return a [StreamResponse] backed by the instrumented stream
 */
fun streamOneArg(params: P): StreamResponse<C> =
    traceStream(params, streamOne(params)) { p -> paramsToMap(p) }

/**
 * Traces the two-argument streaming call.
 *
 * Identical to [streamOneArg], except that the request is issued through [streamTwo] and the
 * recorded inputs additionally carry the request options under the `request_options` key.
 *
 * @param params the request parameters
 * @param opts per-request options forwarded to the underlying service
 * @return a [StreamResponse] backed by the instrumented stream
 */
fun streamTwoArg(params: P, opts: RequestOptions): StreamResponse<C> =
    traceStream(params, streamTwo(params, opts)) { p ->
        paramsToMap(p) + ("request_options" to toGenericMap(opts))
    }

/**
 * Shared plumbing for both streaming entry points: instruments [response]'s stream with
 * `traceable` and wraps the result in a [TracedStreamResponse].
 *
 * The response is already open when this runs, so if building or invoking the traced function
 * throws, the response is closed before the failure is rethrown; otherwise the caller would never
 * receive a handle it could close.
 *
 * @param params the request parameters passed to the traced function
 * @param response the already-opened streaming response
 * @param inputs serializes the params into the map recorded as run inputs
 */
private fun traceStream(
    params: P,
    response: StreamResponse<C>,
    inputs: (P) -> Map<String, Any?>,
): StreamResponse<C> =
    try {
        val tracedStreamFn =
            traceable(
                Function<P, java.util.stream.Stream<C>> { p ->
                    setInvocationParams(paramsToMap(p), useResponsesApi = useResponsesApi)
                    response.stream()
                },
                tracedConfig
                    .toBuilder()
                    .processTracedIO(
                        TraceProcessIO<P, Any>(
                            processInputs = Function { p -> inputs(p) },
                            processOutputs = streamOutputProcessor,
                            aggregator = aggregator,
                        )
                    )
                    .build(),
            )
        TracedStreamResponse(tracedStreamFn.apply(params), response)
    } catch (failure: Throwable) {
        // Best-effort cleanup: keep the original failure primary, attach any close() failure.
        try {
            response.close()
        } catch (closeFailure: Throwable) {
            failure.addSuppressed(closeFailure)
        }
        throw failure
    }
155216
}
156217

157218
private class TracedChatCompletionService(
@@ -166,6 +227,9 @@ private class TracedChatCompletionService(
166227
useResponsesApi = false,
167228
createOne = delegate::create,
168229
createTwo = delegate::create,
230+
streamOne = delegate::createStreaming,
231+
streamTwo = delegate::createStreaming,
232+
aggregator = CHAT_COMPLETION_AGGREGATOR,
169233
)
170234

171235
override fun create(params: ChatCompletionCreateParams): ChatCompletion =
@@ -176,6 +240,15 @@ private class TracedChatCompletionService(
176240
requestOptions: RequestOptions,
177241
): ChatCompletion = traced.twoArg.apply(params, requestOptions)
178242

243+
/** Streams a chat completion with tracing by handing [params] to `traced.streamOneArg`. */
override fun createStreaming(
    params: ChatCompletionCreateParams
): StreamResponse<ChatCompletionChunk> {
    return traced.streamOneArg(params)
}
246+
247+
/**
 * Streams a chat completion with tracing by handing [params] and [requestOptions] to
 * `traced.streamTwoArg`.
 */
override fun createStreaming(
    params: ChatCompletionCreateParams,
    requestOptions: RequestOptions,
): StreamResponse<ChatCompletionChunk> {
    return traced.streamTwoArg(params, requestOptions)
}
251+
179252
override fun withOptions(modifier: Consumer<ClientOptions.Builder>): ChatCompletionService =
180253
TracedChatCompletionService(delegate.withOptions(modifier), config)
181254
}
@@ -192,13 +265,61 @@ private class TracedResponseService(
192265
useResponsesApi = true,
193266
createOne = delegate::create,
194267
createTwo = delegate::create,
268+
streamOne = delegate::createStreaming,
269+
streamTwo = delegate::createStreaming,
270+
aggregator = RESPONSES_AGGREGATOR,
195271
)
196272

197273
override fun create(params: ResponseCreateParams): Response = traced.oneArg.apply(params)
198274

199275
override fun create(params: ResponseCreateParams, requestOptions: RequestOptions): Response =
200276
traced.twoArg.apply(params, requestOptions)
201277

278+
/** Streams a response with tracing by handing [params] to `traced.streamOneArg`. */
override fun createStreaming(
    params: ResponseCreateParams
): StreamResponse<ResponseStreamEvent> {
    return traced.streamOneArg(params)
}
281+
282+
/**
 * Streams a response with tracing by handing [params] and [requestOptions] to
 * `traced.streamTwoArg`.
 */
override fun createStreaming(
    params: ResponseCreateParams,
    requestOptions: RequestOptions,
): StreamResponse<ResponseStreamEvent> {
    return traced.streamTwoArg(params, requestOptions)
}
286+
202287
override fun withOptions(modifier: Consumer<ClientOptions.Builder>): ResponseService =
203288
TracedResponseService(delegate.withOptions(modifier), config)
204289
}
290+
291+
/**
 * A [StreamResponse] that returns an instrumented [java.util.stream.Stream] (with tracing
 * peek/onClose handlers) and delegates [close] to both that stream and the original response.
 *
 * @param C the chunk/event type carried by the stream
 * @property tracedStream the instrumented stream handed out by [stream]
 * @property original the underlying response, closed together with [tracedStream]
 */
private class TracedStreamResponse<C>(
    private val tracedStream: java.util.stream.Stream<C>,
    private val original: StreamResponse<C>,
) : StreamResponse<C> {

    /** Returns the instrumented stream; every call returns the same instance. */
    override fun stream(): java.util.stream.Stream<C> = tracedStream

    /**
     * Closes the instrumented stream first, then the original response.
     *
     * The original response is closed even if closing the instrumented stream fails. When both
     * fail, the first failure is thrown and the second is attached to it as a suppressed
     * exception, so neither is silently lost.
     */
    override fun close() {
        var failure: Throwable? = null
        try {
            tracedStream.close()
        } catch (t: Throwable) {
            failure = t
        }
        try {
            original.close()
        } catch (t: Throwable) {
            if (failure == null) {
                failure = t
            } else {
                failure.addSuppressed(t)
            }
        }
        if (failure != null) {
            throw failure
        }
    }
}
310+
311+
/** Aggregator that uses the SDK's [ChatCompletionAccumulator] to reassemble chunks. */
312+
private val CHAT_COMPLETION_AGGREGATOR =
313+
java.util.function.Function<List<Any?>, Any?> { chunks ->
314+
val acc = ChatCompletionAccumulator.create()
315+
chunks.filterIsInstance<ChatCompletionChunk>().forEach { acc.accumulate(it) }
316+
acc.chatCompletion()
317+
}
318+
319+
/**
 * Stream aggregator for the responses API, built on the SDK's [ResponseAccumulator].
 *
 * Every [ResponseStreamEvent] in the input list is fed to a fresh accumulator in order; elements
 * of any other type (including `null`) are ignored. The result is whatever the accumulator's
 * `response()` returns.
 *
 * NOTE(review): `response()` is called unconditionally, even for an empty or truncated event
 * list; how the accumulator behaves then is not visible here. Confirm against the SDK.
 */
private val RESPONSES_AGGREGATOR =
    Function<List<Any?>, Any?> { items ->
        val accumulator = ResponseAccumulator.create()
        for (item in items) {
            if (item is ResponseStreamEvent) {
                accumulator.accumulate(item)
            }
        }
        accumulator.response()
    }
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
package com.langchain.smith.testutils
2+
3+
import com.langchain.smith.client.LangsmithClient
4+
import com.langchain.smith.client.okhttp.LangsmithOkHttpClient
5+
import com.langchain.smith.models.runs.Run
6+
import com.langchain.smith.models.runs.RunIngestBatchParams
7+
import com.langchain.smith.models.runs.RunIngestBatchResponse
8+
import com.langchain.smith.services.blocking.RunService
9+
import java.lang.reflect.Proxy
10+
11+
/**
12+
* A test [LangsmithClient] that captures [Run] data from `ingestBatch` calls for assertion.
13+
*
14+
* If `LANGSMITH_API_KEY` is set, runs are also forwarded to the real LangSmith server.
15+
*/
16+
internal class CapturingLangsmithClient {
17+
18+
val postedRuns = mutableListOf<Run>()
19+
val patchedRuns = mutableListOf<Run>()
20+
21+
private val realClient: LangsmithClient? =
22+
if (!System.getenv("LANGSMITH_API_KEY").isNullOrBlank()) {
23+
LangsmithOkHttpClient.fromEnv()
24+
} else {
25+
null
26+
}
27+
28+
fun awaitAndGetPostedRuns(delayMs: Long = 500): List<Run> {
29+
Thread.sleep(delayMs)
30+
return postedRuns.toList()
31+
}
32+
33+
fun awaitAndGetPatchedRuns(delayMs: Long = 500): List<Run> {
34+
Thread.sleep(delayMs)
35+
return patchedRuns.toList()
36+
}
37+
38+
val client: LangsmithClient = createProxy()
39+
40+
private fun createProxy(): LangsmithClient {
41+
val runService =
42+
Proxy.newProxyInstance(
43+
RunService::class.java.classLoader,
44+
arrayOf(RunService::class.java),
45+
) { _, method, args ->
46+
if (method.name == "ingestBatch" && args != null && args.isNotEmpty()) {
47+
val params = args[0]
48+
if (params is RunIngestBatchParams) {
49+
params.post().ifPresent { postedRuns.addAll(it) }
50+
params.patch().ifPresent { patchedRuns.addAll(it) }
51+
}
52+
// Forward to real LangSmith if available
53+
if (realClient != null) {
54+
method.invoke(realClient.runs(), *args)
55+
} else {
56+
RunIngestBatchResponse.builder().build()
57+
}
58+
} else if (method.name == "withRawResponse" || method.name == "withOptions") {
59+
throw NotImplementedError()
60+
} else {
61+
// Forward unknown methods to real client if available
62+
if (realClient != null && args != null) {
63+
method.invoke(realClient.runs(), *args)
64+
} else if (realClient != null) {
65+
method.invoke(realClient.runs())
66+
} else {
67+
throw NotImplementedError(
68+
"CapturingLangsmithClient: ${method.name} not supported without LANGSMITH_API_KEY"
69+
)
70+
}
71+
}
72+
} as RunService
73+
74+
return Proxy.newProxyInstance(
75+
LangsmithClient::class.java.classLoader,
76+
arrayOf(LangsmithClient::class.java),
77+
) { _, method, _ ->
78+
when (method.name) {
79+
"runs" -> runService
80+
"close" -> {
81+
realClient?.close()
82+
Unit
83+
}
84+
else ->
85+
throw NotImplementedError(
86+
"CapturingLangsmithClient only supports runs() and close()"
87+
)
88+
}
89+
} as LangsmithClient
90+
}
91+
}

0 commit comments

Comments
 (0)