diff --git a/openkb/agent/_markdown.py b/openkb/agent/_markdown.py
new file mode 100644
index 0000000..6fa96e3
--- /dev/null
+++ b/openkb/agent/_markdown.py
@@ -0,0 +1,464 @@
+"""Markdown rendering in Claude Code's terminal style.
+
+Mirrors claude-code's utils/markdown.ts: parse with markdown-it, then map
+each token to Rich primitives. No colors for plain text / bold / italic --
+just terminal styling. Headings are left-aligned.
+"""
+
+from __future__ import annotations
+
+from typing import Any
+
+from markdown_it import MarkdownIt
+from markdown_it.tree import SyntaxTreeNode
+from rich.console import Group, RenderableType
+from rich.syntax import Syntax
+from rich.text import Text
+
+
+# Rich style applied to inline code spans.
+INLINE_CODE_STYLE = "blue"
+# Bar glyph drawn in front of every non-blank blockquote line.
+BLOCKQUOTE_BAR = "\u258e"
+
+
+# Shared parser: the CommonMark preset with the table rule switched on.
+_MD = MarkdownIt("commonmark").enable("table")
+
+
+def render(content: str) -> RenderableType:
+    """Render markdown *content* as a Rich renderable.
+
+    Top-level blocks are rendered one at a time and joined with an empty
+    ``Text`` between neighbours; node types that ``_render_block`` does not
+    handle are dropped.
+
+    Args:
+        content: Markdown source text.
+
+    Returns:
+        An empty ``Text`` when nothing was rendered, otherwise a ``Group``
+        of the rendered blocks interleaved with empty ``Text`` separators.
+    """
+    tokens = _MD.parse(content)
+    tree = SyntaxTreeNode(tokens)
+
+    blocks: list[RenderableType] = []
+    for child in tree.children:
+        rendered = _render_block(child)
+        if rendered is not None:
+            blocks.append(rendered)
+
+    if not blocks:
+        return Text("")
+    parts: list[RenderableType] = [blocks[0]]
+    for block in blocks[1:]:
+        parts.append(Text(""))
+        parts.append(block)
+    return Group(*parts)
+
+
+def _render_block(node: Any) -> RenderableType | None:
+    """Render one block-level syntax-tree node.
+
+    Args:
+        node: A block-level node of the parsed tree.
+
+    Returns:
+        A ``Text`` for headings, paragraphs, rules, lists, blockquotes,
+        tables and HTML blocks; a ``Syntax`` for fenced or indented code;
+        ``None`` for any other node type.
+    """
+    t = node.type
+    if t == "heading":
+        depth = int(node.tag[1:])  # e.g. tag "h2" -> 2
+        text = _render_inline_container(node)
+        if depth == 1:
+            text.stylize("bold italic underline")
+        else:
+            text.stylize("bold")
+        return text
+    if t == "paragraph":
+        return _render_inline_container(node)
+    if t in ("fence", "code_block"):
+        # Only a fence is asked for its info string (first word = language);
+        # an indented code block is always highlighted as plain "text".
+        info_parts = (node.info or "").strip().split() if t == "fence" else []
+        lang = info_parts[0] if info_parts else ""
+        return Syntax(
+            node.content.rstrip("\n"),
+            lang or "text",
+            theme="monokai",
+            background_color="default",
+            word_wrap=True,
+        )
+    if t == "hr":
+        return Text("---")
+    if t in ("bullet_list", "ordered_list"):
+        return _render_list(node, ordered=(t == "ordered_list"), depth=0)
+    if t == "blockquote":
+        return _render_blockquote(node)
+    if t == "table":
+        return _render_table(node)
+    if t == "html_block":
+        return Text(node.content.rstrip("\n"))
+    return None
+
+
+def _render_inline_container(node: Any) -> Text:
+    """Flatten the inline content of a block node into a single ``Text``.
+
+    Args:
+        node: A block node; its first child is taken to be the inline
+            container whose children are the inline nodes.
+
+    Returns:
+        The styled text, or an empty ``Text`` when *node* has no children.
+    """
+    if not node.children:
+        return Text("")
+    inline = node.children[0]
+    out = Text()
+    for child in inline.children or []:
+        _append_inline(child, out)
+    return out
+
+
+def _append_inline(node: Any, out: Text) -> None:
+    """Append the rendering of one inline node to *out* in place.
+
+    Strong and emphasis become bold and italic, inline code uses
+    ``INLINE_CODE_STYLE``, links with an href carry a Rich ``link`` style,
+    images are reduced to their ``src``, both break kinds become a newline,
+    and any other node contributes its ``content`` when it has one.
+
+    Args:
+        node: The inline node to render.
+        out: The ``Text`` being built; mutated in place.
+    """
+    t = node.type
+    if t == "text":
+        out.append(node.content)
+    elif t in ("softbreak", "hardbreak"):
+        out.append("\n")
+    elif t in ("strong", "em"):
+        piece = Text()
+        for child in node.children or []:
+            _append_inline(child, piece)
+        piece.stylize("bold" if t == "strong" else "italic")
+        out.append_text(piece)
+    elif t == "code_inline":
+        out.append(node.content, style=INLINE_CODE_STYLE)
+    elif t == "link":
+        href = node.attrGet("href") or ""
+        piece = Text()
+        for child in node.children or []:
+            _append_inline(child, piece)
+        if href.startswith("mailto:"):
+            email = href[len("mailto:") :]
+            plain = piece.plain
+            if plain and plain != email and plain != href:
+                piece.stylize(f"link {href}")
+                out.append_text(piece)
+            else:
+                # No distinct label: show the bare address, still clickable.
+                out.append(email, style=f"link {href}")
+            return
+        if href:
+            plain = piece.plain
+            if plain and plain != href:
+                piece.stylize(f"link {href}")
+                out.append_text(piece)
+            else:
+                out.append(href, style=f"link {href}")
+        else:
+            out.append_text(piece)
+    elif t == "image":
+        href = node.attrGet("src") or ""
+        out.append(href)
+    elif t in ("html_inline", "html_block"):
+        out.append(node.content)
+    else:
+        content = getattr(node, "content", "")
+        if content:
+            out.append(content)
+
+
+def _append_with_cont_indent(target: Text, source: Text, cont_indent: str) -> None:
+    """Append *source* to *target*, indenting every line after the first.
+
+    Args:
+        target: The ``Text`` being built; mutated in place.
+        source: Possibly multi-line text to append.
+        cont_indent: Prefix inserted after each newline of *source*.
+    """
+    lines = source.split("\n", allow_blank=True)
+    for i, line in enumerate(lines):
+        if i > 0:
+            target.append("\n" + cont_indent)
+        target.append_text(line)
+
+
+def _render_code_as_text(node: Any) -> Text:
+    """Return a code node's content (trailing newlines removed) as dim ``Text``."""
+    return Text(node.content.rstrip("\n"), style="dim")
+
+
+def _render_list(node: Any, ordered: bool, depth: int) -> Text:
+    """Render a bullet or ordered list, recursing into nested lists.
+
+    Args:
+        node: The list node; its children are the list items.
+        ordered: Number the items (from the node's ``start`` attribute,
+            default 1) instead of using ``- `` bullets.
+        depth: Nesting level, 0 for a top-level list. It sets the indent
+            and, for ordered lists, the numbering style.
+
+    Returns:
+        One ``Text`` holding every item, separated by newlines.
+    """
+    result = Text()
+    items = list(node.children)
+    start = 1
+    if ordered:
+        try:
+            start = int(node.attrGet("start") or 1)
+        except (TypeError, ValueError):
+            start = 1
+
+    # Both depend only on the nesting level, so compute them once.
+    indent = " " * depth
+    cont = indent + " "
+    for i, item in enumerate(items):
+        if ordered:
+            prefix = f"{_list_number(depth, start + i)}. "
+        else:
+            prefix = "- "
+        result.append(indent + prefix)
+        first = True
+        for child in item.children or []:
+            if child.type == "paragraph":
+                if not first:
+                    result.append("\n" + cont)
+                _append_with_cont_indent(result, _render_inline_container(child), cont)
+                first = False
+            elif child.type in ("bullet_list", "ordered_list"):
+                result.append("\n")
+                result.append_text(
+                    _render_list(
+                        child,
+                        ordered=(child.type == "ordered_list"),
+                        depth=depth + 1,
+                    )
+                )
+            elif child.type in ("fence", "code_block"):
+                if not first:
+                    result.append("\n" + cont)
+                _append_with_cont_indent(result, _render_code_as_text(child), cont)
+                first = False
+            else:
+                rendered = _render_block(child)
+                # Skip unhandled (None) and non-Text results before emitting
+                # the line break, so no dangling indented line is left behind.
+                if not isinstance(rendered, Text):
+                    continue
+                if not first:
+                    result.append("\n" + cont)
+                _append_with_cont_indent(result, rendered, cont)
+                first = False
+        if i < len(items) - 1:
+            result.append("\n")
+    return result
+
+
+def _list_number(depth: int, n: int) -> str:
+    """Format item number *n* for an ordered list at nesting *depth*.
+
+    Depth 0 uses decimals, depth 1 lowercase letters, depth 2 lowercase
+    roman numerals, and deeper levels decimals again.
+
+    Numbers below 1 are always written as decimals: letters and roman
+    numerals have no zero, so a nested list starting at ``0.`` would
+    otherwise label its first two items "a" or leave the label empty.
+    """
+    if n < 1:
+        return str(n)
+    if depth == 1:
+        return _to_letters(n)
+    if depth == 2:
+        return _to_roman(n)
+    return str(n)
+
+
+def _to_letters(n: int) -> str:
+    """Convert *n* to bijective base-26 letters: 1 -> a, 26 -> z, 27 -> aa.
+
+    Returns "a" for any *n* below 1.
+    """
+    result = ""
+    while n > 0:
+        n -= 1
+        result = chr(ord("a") + (n % 26)) + result
+        n //= 26
+    return result or "a"
+
+
+# (value, numeral) pairs, largest first, including the subtractive forms.
+_ROMAN = [
+    (1000, "m"),
+    (900, "cm"),
+    (500, "d"),
+    (400, "cd"),
+    (100, "c"),
+    (90, "xc"),
+    (50, "l"),
+    (40, "xl"),
+    (10, "x"),
+    (9, "ix"),
+    (5, "v"),
+    (4, "iv"),
+    (1, "i"),
+]
+
+
+def _to_roman(n: int) -> str:
+    """Convert *n* to a lowercase roman numeral (empty string when *n* is below 1)."""
+    out = ""
+    for value, numeral in _ROMAN:
+        while n >= value:
+            out += numeral
+            n -= value
+    return out
+
+
+def _render_blockquote(node: Any) -> Text:
+    """Render a blockquote as italic text with a dim bar before each line.
+
+    Code children are shown as dim text; of the other children only those
+    that render to ``Text`` are kept. Blocks are separated by one blank
+    line, and blank lines get no bar.
+    """
+    inner_blocks: list[Text] = []
+    for child in node.children or []:
+        if child.type in ("fence", "code_block"):
+            inner_blocks.append(_render_code_as_text(child))
+            continue
+        rendered = _render_block(child)
+        if isinstance(rendered, Text):
+            inner_blocks.append(rendered)
+
+    combined = Text()
+    for i, block in enumerate(inner_blocks):
+        if i > 0:
+            combined.append("\n\n")
+        combined.append_text(block)
+    combined.stylize("italic")
+
+    lines = combined.split("\n", allow_blank=True)
+    out = Text()
+    for i, line in enumerate(lines):
+        if i > 0:
+            out.append("\n")
+        if line.plain.strip():
+            out.append(f"{BLOCKQUOTE_BAR} ", style="dim")
+        out.append_text(line)
+    return out
+
+
+def _render_table(node: Any) -> Text:
+    """Render a table as pipe-delimited rows padded to a common width.
+
+    Each column is as wide as its widest cell, with a minimum of 3. Column
+    alignment comes from the header cells' ``style`` attribute.
+
+    Returns:
+        The table text, or an empty ``Text`` when there is no header row.
+    """
+    header_row: list[Text] = []
+    rows: list[list[Text]] = []
+    aligns: list[str | None] = []
+
+    thead = next((c for c in node.children if c.type == "thead"), None)
+    tbody = next((c for c in node.children if c.type == "tbody"), None)
+
+    if thead and thead.children:
+        tr = thead.children[0]
+        for th in tr.children or []:
+            header_row.append(_render_inline_container(th))
+            aligns.append(th.attrGet("style"))
+    if tbody:
+        for tr in tbody.children or []:
+            row: list[Text] = []
+            for td in tr.children or []:
+                row.append(_render_inline_container(td))
+            rows.append(row)
+
+    if not header_row:
+        return Text("")
+
+    widths = [max(3, cell.cell_len) for cell in header_row]
+    for row in rows:
+        for i, cell in enumerate(row):
+            if i < len(widths):
+                widths[i] = max(widths[i], cell.cell_len)
+
+    out = Text()
+    out.append("| ")
+    for i, cell in enumerate(header_row):
+        out.append_text(_pad(cell, widths[i], aligns[i] if i < len(aligns) else None))
+        out.append(" | ")
+    out = _rstrip_trailing_space(out)
+    out.append("\n|")
+    for w in widths:
+        out.append("-" * (w + 2))
+        out.append("|")
+    for row in rows:
+        out.append("\n| ")
+        for i, cell in enumerate(row):
+            # A cell beyond the header's columns is left at its own width.
+            width = widths[i] if i < len(widths) else cell.cell_len
+            align = aligns[i] if i < len(aligns) else None
+            out.append_text(_pad(cell, width, align))
+            out.append(" | ")
+        out = _rstrip_trailing_space(out)
+    return out
+
+
+def _pad(cell: Text, width: int, align: str | None) -> Text:
+    """Pad *cell* with spaces to *width* display cells.
+
+    Args:
+        cell: The cell text; returned as-is when already wide enough.
+        width: Target width in terminal cells.
+        align: Alignment hint; a value containing "center" or "right"
+            selects that alignment, anything else pads on the right.
+
+    Returns:
+        *cell* itself when no padding is needed, otherwise a new ``Text``.
+    """
+    padding = max(0, width - cell.cell_len)
+    if not padding:
+        return cell
+    if align and "center" in align:
+        left = padding // 2
+        right = padding - left
+        out = Text(" " * left)
+        out.append_text(cell)
+        out.append(" " * right)
+        return out
+    if align and "right" in align:
+        out = Text(" " * padding)
+        out.append_text(cell)
+        return out
+    out = Text()
+    out.append_text(cell)
+    out.append(" " * padding)
+    return out
+
+
+def _rstrip_trailing_space(text: Text) -> Text:
+    """Return *text* without trailing spaces (only " ", not other whitespace)."""
+    kept = len(text.plain.rstrip(" "))
+    if kept < len(text.plain):
+        return text[:kept]
+    return text
diff --git a/openkb/agent/chat.py b/openkb/agent/chat.py index 42ac9f9..bd67c62 100644 --- a/openkb/agent/chat.py +++ b/openkb/agent/chat.py @@ -189,7 +189,22 @@ def _make_prompt_session(session: ChatSession, style: Style, use_color: bool) -> ) -async def _run_turn(agent: Any, session: ChatSession, user_input: str, style: Style) -> None: +def _make_rich_console() -> Any: + from rich.console import Console + + return Console() + + +def _make_markdown(text: str) -> Any: + from openkb.agent._markdown import render + + return render(text) + + +async def _run_turn( + agent: Any, session: ChatSession, user_input: str, style: Style, + *, use_color: bool = True, raw: bool = False, +) -> None: """Run one agent turn with streaming output and persist the new history.""" from agents import ( RawResponsesStreamEvent, @@ -202,11 +217,29 @@ async
def _run_turn(agent: Any, session: ChatSession, user_input: str, style: St result = Runner.run_streamed(agent, new_input, max_turns=MAX_TURNS) - sys.stdout.write("\n") - sys.stdout.flush() + print() collected: list[str] = [] + segment: list[str] = [] last_was_text = False need_blank_before_text = False + + if use_color and not raw: + from rich.console import Console + from rich.live import Live + + console = _make_rich_console() + else: + console = None # type: ignore[assignment] + + def _start_live() -> Any: + if console is None: + return None + lv = Live(console=console, vertical_overflow="visible") + lv.start() + return lv + + live = _start_live() + try: async for event in result.stream_events(): if isinstance(event, RawResponsesStreamEvent): @@ -214,27 +247,52 @@ async def _run_turn(agent: Any, session: ChatSession, user_input: str, style: St text = event.data.delta if text: if need_blank_before_text: - sys.stdout.write("\n") + if console is not None: + print() + segment = [] + live = _start_live() + else: + sys.stdout.write("\n") need_blank_before_text = False - sys.stdout.write(text) - sys.stdout.flush() collected.append(text) + segment.append(text) last_was_text = True + if live: + if "\n" in text: + joined = "".join(segment) + visible = joined[: joined.rfind("\n") + 1] + if visible: + live.update(_make_markdown(visible)) + else: + sys.stdout.write(text) + sys.stdout.flush() elif isinstance(event, RunItemStreamEvent): item = event.item if item.type == "tool_call_item": if last_was_text: - sys.stdout.write("\n") - sys.stdout.flush() + if live: + if segment: + live.update(_make_markdown("".join(segment))) + live.stop() + live = None + else: + sys.stdout.write("\n") + sys.stdout.flush() last_was_text = False - raw = item.raw_item - name = getattr(raw, "name", "?") - args = getattr(raw, "arguments", "") or "" + raw_item = item.raw_item + name = getattr(raw_item, "name", "?") + args = getattr(raw_item, "arguments", "") or "" + if live: + live.stop() + live = None 
_fmt(style, ("class:tool", _format_tool_line(name, args) + "\n")) need_blank_before_text = True finally: - sys.stdout.write("\n\n") - sys.stdout.flush() + if live: + if segment: + live.update(_make_markdown("".join(segment))) + live.stop() + print() answer = "".join(collected).strip() if not answer: @@ -319,6 +377,7 @@ async def run_chat( session: ChatSession, *, no_color: bool = False, + raw: bool = False, ) -> None: """Run the chat REPL against ``session`` until the user exits.""" from openkb.config import load_config @@ -371,7 +430,7 @@ async def run_chat( append_log(kb_dir / "wiki", "query", user_input) try: - await _run_turn(agent, session, user_input, style) + await _run_turn(agent, session, user_input, style, use_color=use_color, raw=raw) except KeyboardInterrupt: _fmt(style, ("class:error", "\n[aborted]\n")) except Exception as exc: diff --git a/openkb/agent/query.py b/openkb/agent/query.py index 39e0e40..762c314 100644 --- a/openkb/agent/query.py +++ b/openkb/agent/query.py @@ -91,7 +91,14 @@ def get_image(image_path: str) -> ToolOutputImage | ToolOutputText: ) -async def run_query(question: str, kb_dir: Path, model: str, stream: bool = False) -> str: +async def run_query( + question: str, + kb_dir: Path, + model: str, + stream: bool = False, + *, + raw: bool = False, +) -> str: """Run a Q&A query against the knowledge base. Args: @@ -99,6 +106,8 @@ async def run_query(question: str, kb_dir: Path, model: str, stream: bool = Fals kb_dir: Root of the knowledge base. model: LLM model name. stream: If True, print response tokens to stdout as they arrive. + raw: If True, write raw markdown source instead of rendering it + (still keeps tool-call line styling). Returns: The agent's final answer as a string. 
@@ -120,25 +129,93 @@ async def run_query(question: str, kb_dir: Path, model: str, stream: bool = Fals result = await Runner.run(agent, question, max_turns=MAX_TURNS) return result.final_output or "" + import os + use_color = sys.stdout.isatty() and not os.environ.get("NO_COLOR", "") + + from openkb.agent.chat import ( + _build_style, + _fmt, + _format_tool_line, + _make_markdown, + _make_rich_console, + ) + + style = _build_style(use_color) + + from rich.live import Live + + if use_color and not raw: + console = _make_rich_console() + else: + console = None # type: ignore[assignment] + + def _start_live() -> Live | None: + if console is None: + return None + lv = Live(console=console, vertical_overflow="visible") + lv.start() + return lv + + live: Live | None = None + last_was_text = False + need_blank_before_text = False result = Runner.run_streamed(agent, question, max_turns=MAX_TURNS) - collected = [] - async for event in result.stream_events(): - if isinstance(event, RawResponsesStreamEvent): - if isinstance(event.data, ResponseTextDeltaEvent): - text = event.data.delta - if text: - sys.stdout.write(text) - sys.stdout.flush() - collected.append(text) - elif isinstance(event, RunItemStreamEvent): - item = event.item - if item.type == "tool_call_item": - raw = item.raw_item - args = getattr(raw, "arguments", "{}") - sys.stdout.write(f"\n[tool call] {raw.name}({args})\n\n") - sys.stdout.flush() - elif item.type == "tool_call_output_item": - pass - sys.stdout.write("\n") - sys.stdout.flush() + collected: list[str] = [] + segment: list[str] = [] + try: + live = _start_live() + async for event in result.stream_events(): + if isinstance(event, RawResponsesStreamEvent): + if isinstance(event.data, ResponseTextDeltaEvent): + text = event.data.delta + if text: + if need_blank_before_text: + if console is not None: + print() + segment = [] + live = _start_live() + else: + sys.stdout.write("\n") + need_blank_before_text = False + collected.append(text) + 
segment.append(text) + last_was_text = True + if live: + if "\n" in text: + joined = "".join(segment) + visible = joined[: joined.rfind("\n") + 1] + if visible: + live.update(_make_markdown(visible)) + else: + sys.stdout.write(text) + sys.stdout.flush() + elif isinstance(event, RunItemStreamEvent): + item = event.item + if item.type == "tool_call_item": + if last_was_text: + if live: + if segment: + live.update(_make_markdown("".join(segment))) + live.stop() + live = None + else: + sys.stdout.write("\n") + sys.stdout.flush() + last_was_text = False + raw_item = item.raw_item + name = getattr(raw_item, "name", "?") + args = getattr(raw_item, "arguments", "") or "" + if live: + live.stop() + live = None + _fmt(style, ("class:tool", _format_tool_line(name, args) + "\n")) + need_blank_before_text = True + elif item.type == "tool_call_output_item": + pass + finally: + if live: + if segment: + live.update(_make_markdown("".join(segment))) + live.stop() + print() return "".join(collected) if collected else result.final_output or "" diff --git a/openkb/cli.py b/openkb/cli.py index d91789f..1336e23 100644 --- a/openkb/cli.py +++ b/openkb/cli.py @@ -360,8 +360,13 @@ def add(ctx, path): @cli.command() @click.argument("question") @click.option("--save", is_flag=True, default=False, help="Save the answer to wiki/explorations/.") +@click.option( + "--raw", "raw", + is_flag=True, default=False, + help="Show raw markdown source instead of rendered output (keeps tool-call colors).", +) @click.pass_context -def query(ctx, question, save): +def query(ctx, question, save, raw): """Query the knowledge base with QUESTION.""" kb_dir = _find_kb_dir(ctx.obj.get("kb_dir_override")) if kb_dir is None: @@ -376,7 +381,7 @@ def query(ctx, question, save): model: str = config.get("model", DEFAULT_CONFIG["model"]) try: - answer = asyncio.run(run_query(question, kb_dir, model, stream=True)) + answer = asyncio.run(run_query(question, kb_dir, model, stream=True, raw=raw)) except Exception as exc: 
click.echo(f"[ERROR] Query failed: {exc}") return @@ -416,8 +421,13 @@ def query(ctx, question, save): is_flag=True, default=False, help="Disable colored output.", ) +@click.option( + "--raw", "raw", + is_flag=True, default=False, + help="Show raw markdown source instead of rendered output (keeps prompt and tool-call colors).", +) @click.pass_context -def chat(ctx, resume, list_sessions_flag, delete_id, no_color): +def chat(ctx, resume, list_sessions_flag, delete_id, no_color, raw): """Start an interactive chat with the knowledge base.""" kb_dir = _find_kb_dir(ctx.obj.get("kb_dir_override")) if kb_dir is None: @@ -491,7 +501,7 @@ def chat(ctx, resume, list_sessions_flag, delete_id, no_color): from openkb.agent.chat import run_chat try: - asyncio.run(run_chat(kb_dir, session, no_color=no_color)) + asyncio.run(run_chat(kb_dir, session, no_color=no_color, raw=raw)) except Exception as exc: click.echo(f"[ERROR] Chat failed: {exc}") diff --git a/pyproject.toml b/pyproject.toml index 4af87be..e368a97 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,6 +36,7 @@ dependencies = [ "python-dotenv", "json-repair", "prompt_toolkit>=3.0", + "rich>=13.0", ] [project.urls] diff --git a/tests/test_markdown_renderer.py b/tests/test_markdown_renderer.py new file mode 100644 index 0000000..3cbe80e --- /dev/null +++ b/tests/test_markdown_renderer.py @@ -0,0 +1,38 @@ +from rich.console import Group +from rich.text import Text + +from openkb.agent._markdown import render + + +def _group_text(renderable: Group) -> list[str]: + return [part.plain for part in renderable.renderables if isinstance(part, Text)] + + +def test_render_preserves_inline_html(): + rendered = render("hello
world") + + assert isinstance(rendered, Group) + assert _group_text(rendered) == ["hello
world"] + + +def test_render_preserves_inline_html_tags(): + rendered = render("H2O and x2") + + assert isinstance(rendered, Group) + assert _group_text(rendered) == ["H2O and x2"] + + +def test_render_preserves_html_block(): + rendered = render("
\nMore\nHidden text\n
") + + assert isinstance(rendered, Group) + assert _group_text(rendered) == [ + "
\nMore\nHidden text\n
", + ] + + +def test_render_keeps_html_block_between_paragraphs(): + rendered = render("before\n\n
hello
\n\nafter") + + assert isinstance(rendered, Group) + assert _group_text(rendered) == ["before", "", "
hello
", "", "after"]