Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
229 changes: 206 additions & 23 deletions openclaw-tinker/api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,10 @@ def _flatten_content(content) -> str:


def _normalize_messages(messages: list[dict]) -> list[dict]:
"""Normalize messages for the chat template (developer -> system, flatten content)."""
"""Normalize messages for the chat template.

Qwen3.5's chat template expects replayed tool_call.arguments to be a dict.
"""
out = []
for msg in messages:
m = dict(msg)
Expand All @@ -70,6 +73,26 @@ def _normalize_messages(messages: list[dict]) -> list[dict]:
raw = m.get("content")
if not isinstance(raw, str) and raw is not None:
m["content"] = _flatten_content(raw)
tool_calls = m.get("tool_calls")
if isinstance(tool_calls, list):
normalized_calls = []
for tool_call in tool_calls:
if not isinstance(tool_call, dict):
normalized_calls.append(tool_call)
continue
tool_call_copy = dict(tool_call)
function = tool_call_copy.get("function")
if isinstance(function, dict):
function_copy = dict(function)
arguments = function_copy.get("arguments")
if isinstance(arguments, str):
try:
function_copy["arguments"] = json.loads(arguments)
except (json.JSONDecodeError, TypeError, ValueError):
function_copy["arguments"] = {}
tool_call_copy["function"] = function_copy
normalized_calls.append(tool_call_copy)
m["tool_calls"] = normalized_calls
out.append(m)
return out

Expand All @@ -90,26 +113,139 @@ def _extract_logprobs(choice: dict[str, Any]) -> list[float]:
r"<\|tool_call_argument_begin\|>\s*(\{.*?\})\s*<\|tool_call_end\|>",
re.DOTALL,
)
_QWEN_TC_RE = re.compile(r"<tool_call>\s*(.*?)\s*</tool_call>", re.DOTALL)
_QWEN_JSON_TC_RE = re.compile(r"<tool_call>\s*(\{.*?\})\s*</tool_call>", re.DOTALL)
_QWEN_XML_TC_RE = re.compile(
r"<tool_call>\s*<function=([^>]+)>(.*?)</function>\s*</tool_call>",
re.DOTALL,
)
_KIMI_TOOL_CALL_SECTION_RE = re.compile(
r"<\|tool_calls_section_(?:begin|end)\|>",
re.DOTALL,
)


def _resolve_enable_thinking(body: dict[str, Any]) -> bool:
enable_thinking = True
extra_body = body.get("extra_body")
if isinstance(extra_body, dict):
chat_template_kwargs = extra_body.get("chat_template_kwargs")
if isinstance(chat_template_kwargs, dict) and "enable_thinking" in chat_template_kwargs:
return bool(chat_template_kwargs["enable_thinking"])
if "enable_thinking" in extra_body:
return bool(extra_body["enable_thinking"])
chat_template_kwargs = body.get("chat_template_kwargs")
if isinstance(chat_template_kwargs, dict) and "enable_thinking" in chat_template_kwargs:
return bool(chat_template_kwargs["enable_thinking"])
if "enable_thinking" in body:
return bool(body["enable_thinking"])
return enable_thinking


def _apply_chat_template_with_fallbacks(tokenizer, messages: list[dict], *, enable_thinking: bool, tools=None, tool_choice=None) -> str:
optional_items: list[tuple[str, Any]] = [("enable_thinking", enable_thinking)]
if tools:
optional_items.append(("tools", tools))
if tool_choice is not None and tools:
optional_items.append(("tool_choice", tool_choice))

current_items = list(optional_items)
while True:
kwargs = {"tokenize": False, "add_generation_prompt": True}
kwargs.update({key: value for key, value in current_items})
try:
return tokenizer.apply_chat_template(messages, **kwargs)
except TypeError:
if not current_items:
raise
current_items.pop()


def _stringify_json_arguments(arguments_text: str) -> str:
try:
return json.dumps(json.loads(arguments_text), ensure_ascii=False)
except (json.JSONDecodeError, TypeError, ValueError):
return arguments_text


def _remove_spans(text: str, spans: list[tuple[int, int]]) -> str:
    """Delete the given ``(start, end)`` spans from *text* and tidy the result.

    Overlapping or touching spans are merged first, so a span that begins
    inside an earlier one but extends past its end is removed in full
    (the previous implementation skipped such spans entirely, leaving
    their tails behind).  After removal, stray Kimi tool-call section
    markers are stripped and runs of 3+ newlines collapse to one blank
    line; the result is whitespace-trimmed.
    """
    if not spans:
        return text
    # Merge overlapping/adjacent spans into disjoint intervals.
    merged: list[list[int]] = []
    for start, end in sorted(spans):
        if merged and start <= merged[-1][1]:
            merged[-1][1] = max(merged[-1][1], end)
        else:
            merged.append([start, end])
    parts: list[str] = []
    cursor = 0
    for start, end in merged:
        parts.append(text[cursor:start])
        cursor = end
    parts.append(text[cursor:])
    remaining = "".join(parts)
    remaining = _KIMI_TOOL_CALL_SECTION_RE.sub("", remaining)
    remaining = re.sub(r"\n{3,}", "\n\n", remaining)
    return remaining.strip()


def _strip_thinking(text: str) -> str:
    """Remove <think>...</think> blocks and any orphan think tags, then trim."""
    cleaned = _THINK_RE.sub("", text)
    for tag in ("<think>", "</think>"):
        cleaned = cleaned.replace(tag, "")
    return cleaned.strip()


def _split_thinking(text: str) -> tuple[str | None, str]:
stripped = text.lstrip()

full_block_match = re.match(r"^<think>\s*(.*?)\s*</think>\s*(.*)$", stripped, flags=re.DOTALL)
if full_block_match:
reasoning = full_block_match.group(1).strip() or None
return reasoning, full_block_match.group(2).strip()

truncated_block_match = re.match(r"^<think>\s*(.*)$", stripped, flags=re.DOTALL)
if truncated_block_match:
reasoning = truncated_block_match.group(1).strip() or None
return reasoning, ""

orphan_close_match = re.match(r"^(.*?)</think>\s*(.*)$", stripped, flags=re.DOTALL)
if orphan_close_match:
reasoning = orphan_close_match.group(1).replace("<think>", "").strip() or None
return reasoning, orphan_close_match.group(2).strip()

return None, text


def _extract_tool_calls(text: str) -> tuple[str, list[dict]]:
"""Parse tool-call tags from assistant text into OpenAI-style tool_calls."""
if not text:
return "", []
stripped = _strip_thinking(text)
tool_calls = []
for i, m in enumerate(_KIMI_TC_RE.finditer(text)):
spans: list[tuple[int, int]] = []
for i, m in enumerate(_KIMI_TC_RE.finditer(stripped)):
raw_name = (m.group(1) or "").strip()
args_raw = (m.group(2) or "{}").strip()
try:
args_str = json.dumps(json.loads(args_raw), ensure_ascii=False)
except Exception:
args_str = args_raw
args_str = _stringify_json_arguments((m.group(2) or "{}").strip())
tool_calls.append({
"id": f"call_{i}", "type": "function",
"function": {"name": raw_name or "unknown_tool", "arguments": args_str},
})
for i, m in enumerate(_QWEN_TC_RE.finditer(text), start=len(tool_calls)):
spans.append(m.span())
for i, m in enumerate(_QWEN_XML_TC_RE.finditer(stripped), start=len(tool_calls)):
func_name = (m.group(1) or "").strip() or "unknown_tool"
body = (m.group(2) or "").strip()
arguments = {}
for param_match in re.finditer(
r"<parameter=([^>]+)>\s*(.*?)\s*</parameter>",
body,
re.DOTALL,
):
param_name = param_match.group(1).strip()
param_value = param_match.group(2).strip()
try:
arguments[param_name] = json.loads(param_value)
except (json.JSONDecodeError, TypeError, ValueError):
arguments[param_name] = param_value
tool_calls.append({
"id": f"call_{i}", "type": "function",
"function": {"name": func_name, "arguments": json.dumps(arguments, ensure_ascii=False)},
})
spans.append(m.span())
for i, m in enumerate(_QWEN_JSON_TC_RE.finditer(stripped), start=len(tool_calls)):
try:
payload = json.loads(m.group(1).strip())
except Exception:
Expand All @@ -122,12 +258,10 @@ def _extract_tool_calls(text: str) -> tuple[str, list[dict]]:
"id": f"call_{i}", "type": "function",
"function": {"name": str(name), "arguments": args},
})
clean = _THINK_RE.sub("", text)
clean = clean.replace("</think>", "")
clean = re.sub(r"<\|tool_call_begin\|>.*?<\|tool_call_end\|>", "", clean, flags=re.DOTALL)
clean = re.sub(r"<\|tool_calls_section_begin\|>.*?<\|tool_calls_section_end\|>", "", clean, flags=re.DOTALL)
clean = _QWEN_TC_RE.sub("", clean)
return clean.strip(), tool_calls
spans.append(m.span())
if tool_calls:
return _remove_spans(stripped, spans), tool_calls
return stripped, []


# ===========================================================================
Expand Down Expand Up @@ -241,12 +375,20 @@ async def _forward_to_tinker(self, body: dict[str, Any]) -> dict[str, Any]:
messages = body.get("messages", [])
norm_msgs = _normalize_messages(messages)
tools = body.get("tools")
tool_choice = body.get("tool_choice")
parse_tool_calls = tools is not None and tool_choice != "none"
template_tools = tools if parse_tool_calls else None
enable_thinking = _resolve_enable_thinking(body)
temperature = float(body.get("temperature", 0.6))
max_tokens = int(body.get("max_tokens") or 2048)
stop = body.get("stop")

prompt_text = self._tokenizer.apply_chat_template(
norm_msgs, tools=tools, tokenize=False, add_generation_prompt=True,
prompt_text = _apply_chat_template_with_fallbacks(
self._tokenizer,
norm_msgs,
enable_thinking=enable_thinking,
tools=template_tools,
tool_choice=tool_choice if parse_tool_calls else None,
)
prompt_ids = self._tokenizer.encode(prompt_text, add_special_tokens=False)

Expand All @@ -268,10 +410,21 @@ async def _forward_to_tinker(self, body: dict[str, Any]) -> dict[str, Any]:
raw_response_logprobs = [float(lp) for lp in (seq.logprobs or [])]

response_text = self._tokenizer.decode(seq.tokens, skip_special_tokens=True)
normalized_text, parsed_tool_calls = _extract_tool_calls(response_text)
reasoning_content, visible_text = _split_thinking(response_text)
if not enable_thinking:
reasoning_content = None
if parse_tool_calls:
normalized_text, parsed_tool_calls = _extract_tool_calls(visible_text)
else:
normalized_text, parsed_tool_calls = visible_text, []

lp_content = [{"token": "", "logprob": lp, "top_logprobs": []} for lp in raw_response_logprobs]
assistant_message: dict[str, Any] = {"role": "assistant", "content": normalized_text}
assistant_message: dict[str, Any] = {
"role": "assistant",
"content": normalized_text if (normalized_text or not parsed_tool_calls) else None,
}
if reasoning_content is not None:
assistant_message["reasoning_content"] = reasoning_content
if parsed_tool_calls:
assistant_message["tool_calls"] = parsed_tool_calls

Expand Down Expand Up @@ -416,16 +569,46 @@ async def _stream_response(self, result: dict):
payload = result["response"]
choice = payload.get("choices", [{}])[0]
message = choice.get("message", {})
delta = {"role": "assistant", "content": message.get("content", "") or ""}
if message.get("tool_calls"):
delta["tool_calls"] = message["tool_calls"]
base = {
"id": payload.get("id", ""), "object": "chat.completion.chunk",
"created": payload.get("created", int(time.time())),
"model": payload.get("model", ""),
"session_id": payload.get("session_id", ""),
}
yield f"data: {json.dumps({**base, 'choices': [{'index': 0, 'delta': delta, 'finish_reason': None}]}, ensure_ascii=False)}\n\n"
tool_calls = message.get("tool_calls") or []
content = message.get("content")
if tool_calls:
for i, tool_call in enumerate(tool_calls):
delta: dict[str, Any] = {"role": "assistant"} if i == 0 else {}
if i == 0 and content:
delta["content"] = content
delta["tool_calls"] = [{
"index": i,
"id": tool_call["id"],
"type": "function",
"function": {
"name": tool_call["function"]["name"],
"arguments": tool_call["function"]["arguments"],
},
}]
yield (
"data: "
+ json.dumps(
{**base, "choices": [{"index": 0, "delta": delta, "finish_reason": None}]},
ensure_ascii=False,
)
+ "\n\n"
)
else:
delta = {"role": "assistant", "content": content or ""}
yield (
"data: "
+ json.dumps(
{**base, "choices": [{"index": 0, "delta": delta, "finish_reason": None}]},
ensure_ascii=False,
)
+ "\n\n"
)
yield f"data: {json.dumps({**base, 'choices': [{'index': 0, 'delta': {}, 'finish_reason': choice.get('finish_reason', 'stop')}]}, ensure_ascii=False)}\n\n"
yield "data: [DONE]\n\n"

Expand Down
37 changes: 29 additions & 8 deletions openclaw-tinker/data_formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,16 +103,28 @@ def _build_datum(all_tokens: list[int], logprobs: list[float], advantages: list[
# RL / OPD datum conversion
# ---------------------------------------------------------------------------

def sample_to_datum(sample: TrainingSample, advantage: float):
def sample_to_datum(sample: TrainingSample, advantage: float, max_tokens: int = 0):
"""Convert one sample + scalar advantage into a Tinker Datum (RL / OPD).

For OPD samples with teacher_logprobs, the advantage is augmented with
per-token distillation signal: (teacher_lp - student_lp).
This matches Slime's --advantage-estimator on_policy_distillation where
advantage = teacher_logp - old_logp (raw, no coefficient).
"""
prompt_len = len(sample.prompt_tokens)
all_tokens = sample.prompt_tokens + sample.response_tokens
prompt_tokens = sample.prompt_tokens
# Truncate prompt from left if total exceeds max_tokens (keep response intact)
if max_tokens > 0:
total = len(prompt_tokens) + len(sample.response_tokens)
if total > max_tokens:
keep = max(1, max_tokens - len(sample.response_tokens))
trimmed = len(prompt_tokens) - keep
prompt_tokens = prompt_tokens[-keep:]
logger.info(
"[DataFormatter] truncated prompt: %d -> %d tokens (session=%s turn=%d)",
trimmed + keep, keep, sample.session_id, sample.turn_num,
)
prompt_len = len(prompt_tokens)
all_tokens = prompt_tokens + sample.response_tokens

logprobs = [0.0] * (prompt_len - 1) + list(sample.response_logprobs)
resp_advantages = [advantage * float(m) for m in sample.loss_mask]
Expand All @@ -128,12 +140,13 @@ def sample_to_datum(sample: TrainingSample, advantage: float):
return _build_datum(all_tokens, logprobs, advantages, sample.session_id, sample.turn_num)


def batch_to_datums(batch: list[TrainingSample], advantages: list[float]) -> list:
def batch_to_datums(batch: list[TrainingSample], advantages: list[float],
max_tokens: int = 0) -> list:
"""Convert a batch of samples + per-sample scalar advantages to Tinker Datums."""
datums = []
for sample, adv in zip(batch, advantages):
try:
datums.append(sample_to_datum(sample, adv))
datums.append(sample_to_datum(sample, adv, max_tokens=max_tokens))
except Exception as e:
logger.error(
"[DataFormatter] FAILED to convert session=%s turn=%d: %s",
Expand All @@ -150,6 +163,7 @@ def sample_to_datum_combined(
sample: TrainingSample,
w_opd: float = 1.0,
w_rl: float = 1.0,
max_tokens: int = 0,
):
"""Convert one sample into a Tinker Datum with combined OPD+RL advantages.

Expand All @@ -160,8 +174,14 @@ def sample_to_datum_combined(
where teacher_advantages = teacher_logp - old_logp (token-level, raw)
and grpo_advantages = reward broadcast (scalar)
"""
prompt_len = len(sample.prompt_tokens)
all_tokens = sample.prompt_tokens + sample.response_tokens
prompt_tokens = sample.prompt_tokens
if max_tokens > 0:
total = len(prompt_tokens) + len(sample.response_tokens)
if total > max_tokens:
keep = max(1, max_tokens - len(sample.response_tokens))
prompt_tokens = prompt_tokens[-keep:]
prompt_len = len(prompt_tokens)
all_tokens = prompt_tokens + sample.response_tokens

logprobs = [0.0] * (prompt_len - 1) + list(sample.response_logprobs)

Expand Down Expand Up @@ -189,13 +209,14 @@ def batch_to_datums_combined(
batch: list[TrainingSample],
w_opd: float = 1.0,
w_rl: float = 1.0,
max_tokens: int = 0,
) -> list:
"""Convert a batch of samples to Tinker Datums with combined advantages."""
datums = []
for sample in batch:
try:
datums.append(sample_to_datum_combined(
sample, w_opd=w_opd, w_rl=w_rl,
sample, w_opd=w_opd, w_rl=w_rl, max_tokens=max_tokens,
))
except Exception as e:
logger.error(
Expand Down
Loading