@@ -120,25 +120,47 @@ async def run_query(question: str, kb_dir: Path, model: str, stream: bool = Fals
120120 result = await Runner .run (agent , question , max_turns = MAX_TURNS )
121121 return result .final_output or ""
122122
123+ use_color = sys .stdout .isatty ()
124+
125+ if use_color :
126+ from rich .console import Console
127+ from rich .live import Live
128+ from rich .markdown import Markdown
129+ console = Console ()
130+ live = Live (console = console , vertical_overflow = "visible" )
131+ live .start ()
132+ else :
133+ live = None
134+
123135 result = Runner .run_streamed (agent , question , max_turns = MAX_TURNS )
124- collected = []
125- async for event in result .stream_events ():
126- if isinstance (event , RawResponsesStreamEvent ):
127- if isinstance (event .data , ResponseTextDeltaEvent ):
128- text = event .data .delta
129- if text :
130- sys .stdout .write (text )
136+ collected : list [str ] = []
137+ try :
138+ async for event in result .stream_events ():
139+ if isinstance (event , RawResponsesStreamEvent ):
140+ if isinstance (event .data , ResponseTextDeltaEvent ):
141+ text = event .data .delta
142+ if text :
143+ collected .append (text )
144+ if live :
145+ live .update (Markdown ("" .join (collected )))
146+ else :
147+ sys .stdout .write (text )
148+ sys .stdout .flush ()
149+ elif isinstance (event , RunItemStreamEvent ):
150+ item = event .item
151+ if item .type == "tool_call_item" :
152+ raw = item .raw_item
153+ args = getattr (raw , "arguments" , "{}" )
154+ if live :
155+ live .stop ()
156+ sys .stdout .write (f"\n [tool call] { raw .name } ({ args } )\n \n " )
131157 sys .stdout .flush ()
132- collected .append (text )
133- elif isinstance (event , RunItemStreamEvent ):
134- item = event .item
135- if item .type == "tool_call_item" :
136- raw = item .raw_item
137- args = getattr (raw , "arguments" , "{}" )
138- sys .stdout .write (f"\n [tool call] { raw .name } ({ args } )\n \n " )
139- sys .stdout .flush ()
140- elif item .type == "tool_call_output_item" :
141- pass
142- sys .stdout .write ("\n " )
143- sys .stdout .flush ()
158+ if live :
159+ live .start ()
160+ elif item .type == "tool_call_output_item" :
161+ pass
162+ finally :
163+ if live :
164+ live .stop ()
165+ print ()
144166 return "" .join (collected ) if collected else result .final_output or ""
0 commit comments