@@ -10,10 +10,15 @@ import com.langchain.smith.tracing.traceable
1010import com.openai.client.OpenAIClient
1111import com.openai.core.ClientOptions
1212import com.openai.core.RequestOptions
13+ import com.openai.core.http.StreamResponse
14+ import com.openai.helpers.ChatCompletionAccumulator
15+ import com.openai.helpers.ResponseAccumulator
1316import com.openai.models.chat.completions.ChatCompletion
17+ import com.openai.models.chat.completions.ChatCompletionChunk
1418import com.openai.models.chat.completions.ChatCompletionCreateParams
1519import com.openai.models.responses.Response
1620import com.openai.models.responses.ResponseCreateParams
21+ import com.openai.models.responses.ResponseStreamEvent
1722import com.openai.services.blocking.ChatService
1823import com.openai.services.blocking.ResponseService
1924import com.openai.services.blocking.chat.ChatCompletionService
@@ -23,11 +28,10 @@ import java.util.function.Function
2328/* *
2429 * Wraps an [OpenAIClient] with LangSmith tracing.
2530 *
26- * The returned client traces `chat().completions().create()` and `responses().create()` calls using
27- * [traceable], recording inputs and outputs as JSON maps. All other methods — including
28- * `createStreaming()` — pass through to the delegate unchanged.
29- *
30- * // TODO: Trace `createStreaming()` for both chat completions and responses.
31+ * The returned client traces `chat().completions().create()`, `responses().create()`, and their
32+ * streaming counterparts (`createStreaming()`) using [traceable], recording inputs and outputs as
33+ * JSON maps. Streaming calls aggregate chunks using the SDK's built-in accumulators and record the
34+ * final aggregated result when the stream closes.
3135 *
3236 * The [config] is used as the base configuration for all traced calls — the wrapper overrides
3337 * [TraceConfig.name] and [TraceConfig.runType] per operation.
@@ -55,7 +59,8 @@ import java.util.function.Function
5559 * @param config base tracing configuration (client, project, tags, etc.)
5660 * @return a wrapped [OpenAIClient] that traces LLM calls to LangSmith
5761 */
@JvmOverloads
fun wrapOpenAI(client: OpenAIClient, config: TraceConfig = TraceConfig()): OpenAIClient {
    // The wrapper class does all the work; this function is only the public entry point.
    return TracedOpenAIClient(client, config)
}
6065
6166private class TracedOpenAIClient (
@@ -90,41 +95,42 @@ private class TracedChatService(
9095}
9196
9297/* *
93- * Builds the traced `create` functions (1-arg and 2-arg) for an OpenAI service.
98+ * Builds the traced `create` and `createStreaming` functions for an OpenAI service.
9499 *
95100 * Both [TracedChatCompletionService] and [TracedResponseService] follow the same pattern — this
96101 * extracts the shared tracing plumbing so each service only provides the type-specific bits.
97102 *
98103 * @param P the params type (e.g. [ChatCompletionCreateParams])
99104 * @param R the response type (e.g. [ChatCompletion])
100- * @param config base tracing configuration
101- * @param paramsToMap serializes params to a map for recording as run inputs
102- * @param useResponsesApi whether this is the responses API (affects metadata)
103- * @param createOne calls the underlying service's `create(params)` method
104- * @param createTwo calls the underlying service's `create(params, requestOptions)` method
105+ * @param C the chunk/event type for streaming (e.g. [ChatCompletionChunk])
105106 */
106- private class TracedOpenAIClientCreate <P , R : Any >(
107+ private class TracedOpenAIClientCreate <P , R : Any , C >(
107108 config : TraceConfig ,
108- paramsToMap : (P ) -> Map <String , Any ?>,
109- useResponsesApi : Boolean ,
109+ private val paramsToMap : (P ) -> Map <String , Any ?>,
110+ private val useResponsesApi : Boolean ,
110111 createOne : (P ) -> R ,
111112 createTwo : (P , RequestOptions ) -> R ,
113+ private val streamOne : (P ) -> StreamResponse <C >,
114+ private val streamTwo : (P , RequestOptions ) -> StreamResponse <C >,
115+ private val aggregator : java.util.function.Function <List <Any ?>, Any? >,
112116) {
117+ private val tracedConfig = config.toBuilder().name(" ChatOpenAI" ).runType(RunType .LLM ).build()
118+
119+ private val outputProcessor =
120+ Function <R , Map <String , Any ?>> { response -> processChatCompletionOutput(toMap(response)) }
121+
113122 val oneArg: Function <P , R > =
114123 traceable(
115- Function <P , R > {
116- setInvocationParams(paramsToMap(it ), useResponsesApi = useResponsesApi)
117- createOne(it )
124+ Function <P , R > { p ->
125+ setInvocationParams(paramsToMap(p ), useResponsesApi = useResponsesApi)
126+ createOne(p )
118127 },
119- config
128+ tracedConfig
120129 .toBuilder()
121- .name(" ChatOpenAI" )
122- .runType(RunType .LLM )
123130 .processTracedIO(
124131 TraceProcessIO <P , R >(
125132 processInputs = Function { params -> paramsToMap(params) },
126- processOutputs =
127- Function { response -> processChatCompletionOutput(toMap(response)) },
133+ processOutputs = outputProcessor,
128134 )
129135 )
130136 .build(),
@@ -136,22 +142,77 @@ private class TracedOpenAIClientCreate<P, R : Any>(
136142 setInvocationParams(paramsToMap(params), useResponsesApi = useResponsesApi)
137143 createTwo(params, opts)
138144 },
139- config
145+ tracedConfig
140146 .toBuilder()
141- .name(" ChatOpenAI" )
142- .runType(RunType .LLM )
143147 .processTracedIO(
144148 TraceProcessIO <Pair <P , RequestOptions >, R > (
145149 processInputs =
146150 Function { (params, opts) ->
147151 paramsToMap(params) + (" request_options" to toGenericMap(opts))
148152 },
149- processOutputs =
150- Function { response -> processChatCompletionOutput(toMap(response)) },
153+ processOutputs = outputProcessor,
151154 )
152155 )
153156 .build(),
154157 )
158+
159+ private val streamOutputProcessor =
160+ Function <Any , Map <String , Any ?>> { processChatCompletionOutput(toMap(it!! )) }
161+
/**
 * Traces the one-argument `createStreaming(params)` call.
 *
 * Opens the stream through [streamOne] and hands it to [traceStream]; the recorded run inputs
 * are `paramsToMap(params)`.
 *
 * @param params the request parameters
 * @return a [StreamResponse] whose stream is instrumented for tracing
 */
fun streamOneArg(params: P): StreamResponse<C> =
    traceStream(params, streamOne(params)) { p -> paramsToMap(p) }

/**
 * Traces the two-argument `createStreaming(params, requestOptions)` call.
 *
 * Identical to [streamOneArg] except that the recorded run inputs also carry the request
 * options under the `request_options` key.
 *
 * @param params the request parameters
 * @param opts per-request options forwarded to [streamTwo]
 * @return a [StreamResponse] whose stream is instrumented for tracing
 */
fun streamTwoArg(params: P, opts: RequestOptions): StreamResponse<C> =
    traceStream(params, streamTwo(params, opts)) { p ->
        paramsToMap(p) + ("request_options" to toGenericMap(opts))
    }

/**
 * Shared plumbing for both streaming entry points.
 *
 * The traced function returns the inner [java.util.stream.Stream] of [response]; the result of
 * applying it is rewrapped in a [TracedStreamResponse] so that `close()` reaches both the
 * traced stream and the original response.
 *
 * If anything fails before the wrapper is handed back, [response] is closed here: the caller
 * never receives a handle to it, so nobody else could release it.
 *
 * @param params the request parameters passed to the traced function
 * @param response the already-open streaming response to instrument
 * @param inputsOf converts the params into the map recorded as run inputs
 */
private fun traceStream(
    params: P,
    response: StreamResponse<C>,
    inputsOf: (P) -> Map<String, Any?>,
): StreamResponse<C> =
    try {
        val tracedStreamFn =
            traceable(
                Function<P, java.util.stream.Stream<C>> { p ->
                    setInvocationParams(paramsToMap(p), useResponsesApi = useResponsesApi)
                    response.stream()
                },
                tracedConfig
                    .toBuilder()
                    .processTracedIO(
                        TraceProcessIO<P, Any>(
                            processInputs = Function { p -> inputsOf(p) },
                            processOutputs = streamOutputProcessor,
                            aggregator = aggregator,
                        )
                    )
                    .build(),
            )
        TracedStreamResponse(tracedStreamFn.apply(params), response)
    } catch (failure: Throwable) {
        // Release the underlying response, but keep the original failure as the primary error.
        try {
            response.close()
        } catch (closeFailure: Throwable) {
            failure.addSuppressed(closeFailure)
        }
        throw failure
    }
155216}
156217
157218private class TracedChatCompletionService (
@@ -166,6 +227,9 @@ private class TracedChatCompletionService(
166227 useResponsesApi = false ,
167228 createOne = delegate::create,
168229 createTwo = delegate::create,
230+ streamOne = delegate::createStreaming,
231+ streamTwo = delegate::createStreaming,
232+ aggregator = CHAT_COMPLETION_AGGREGATOR ,
169233 )
170234
171235 override fun create (params : ChatCompletionCreateParams ): ChatCompletion =
@@ -176,6 +240,15 @@ private class TracedChatCompletionService(
176240 requestOptions : RequestOptions ,
177241 ): ChatCompletion = traced.twoArg.apply (params, requestOptions)
178242
/** Streaming chat completion without request options; routed through the traced wrapper. */
override fun createStreaming(
    params: ChatCompletionCreateParams
): StreamResponse<ChatCompletionChunk> {
    return traced.streamOneArg(params)
}
246+
/** Streaming chat completion with request options; routed through the traced wrapper. */
override fun createStreaming(
    params: ChatCompletionCreateParams,
    requestOptions: RequestOptions,
): StreamResponse<ChatCompletionChunk> {
    return traced.streamTwoArg(params, requestOptions)
}
251+
/** Applies [modifier] to the delegate and wraps the result with the same tracing config. */
override fun withOptions(modifier: Consumer<ClientOptions.Builder>): ChatCompletionService {
    val reconfigured = delegate.withOptions(modifier)
    return TracedChatCompletionService(reconfigured, config)
}
181254}
@@ -192,13 +265,61 @@ private class TracedResponseService(
192265 useResponsesApi = true ,
193266 createOne = delegate::create,
194267 createTwo = delegate::create,
268+ streamOne = delegate::createStreaming,
269+ streamTwo = delegate::createStreaming,
270+ aggregator = RESPONSES_AGGREGATOR ,
195271 )
196272
/** Non-streaming response creation without request options, via the traced one-arg function. */
override fun create(params: ResponseCreateParams): Response {
    return traced.oneArg.apply(params)
}
198274
/** Non-streaming response creation with request options, via the traced two-arg function. */
override fun create(params: ResponseCreateParams, requestOptions: RequestOptions): Response {
    return traced.twoArg.apply(params, requestOptions)
}
201277
/** Streaming response creation without request options; routed through the traced wrapper. */
override fun createStreaming(
    params: ResponseCreateParams
): StreamResponse<ResponseStreamEvent> {
    return traced.streamOneArg(params)
}
281+
/** Streaming response creation with request options; routed through the traced wrapper. */
override fun createStreaming(
    params: ResponseCreateParams,
    requestOptions: RequestOptions,
): StreamResponse<ResponseStreamEvent> {
    return traced.streamTwoArg(params, requestOptions)
}
286+
/** Applies [modifier] to the delegate and wraps the result with the same tracing config. */
override fun withOptions(modifier: Consumer<ClientOptions.Builder>): ResponseService {
    val reconfigured = delegate.withOptions(modifier)
    return TracedResponseService(reconfigured, config)
}
204289}
290+
/**
 * [StreamResponse] adapter handed back to callers of the traced streaming methods.
 *
 * [stream] exposes [tracedStream] rather than the original response's own stream, and [close]
 * closes [tracedStream] first and then [original].
 */
private class TracedStreamResponse<C>(
    private val tracedStream: java.util.stream.Stream<C>,
    private val original: StreamResponse<C>,
) : StreamResponse<C> {

    /** Returns the traced stream supplied at construction; the same instance on every call. */
    override fun stream(): java.util.stream.Stream<C> {
        return tracedStream
    }

    /** Closes the traced stream, then the original response even if the first close throws. */
    override fun close() {
        try {
            tracedStream.close()
        } finally {
            original.close()
        }
    }
}
310+
/**
 * Rebuilds a single chat completion from the collected stream items.
 *
 * Every item that is a [ChatCompletionChunk] is fed, in list order, into a fresh
 * [ChatCompletionAccumulator]; items of any other type (including `null`) are ignored. The
 * result is whatever the accumulator's `chatCompletion()` returns.
 */
private val CHAT_COMPLETION_AGGREGATOR =
    java.util.function.Function<List<Any?>, Any?> { chunks ->
        val accumulator = ChatCompletionAccumulator.create()
        for (item in chunks) {
            if (item is ChatCompletionChunk) {
                accumulator.accumulate(item)
            }
        }
        accumulator.chatCompletion()
    }
318+
/**
 * Rebuilds a single response from the collected stream items.
 *
 * Every item that is a [ResponseStreamEvent] is fed, in list order, into a fresh
 * [ResponseAccumulator]; items of any other type (including `null`) are ignored. The result is
 * whatever the accumulator's `response()` returns.
 */
private val RESPONSES_AGGREGATOR =
    java.util.function.Function<List<Any?>, Any?> { events ->
        val accumulator = ResponseAccumulator.create()
        for (item in events) {
            if (item is ResponseStreamEvent) {
                accumulator.accumulate(item)
            }
        }
        accumulator.response()
    }
0 commit comments