From 119206f2099d36541e6234cabc2acf9c359c6485 Mon Sep 17 00:00:00 2001 From: miaoziyang Date: Sat, 21 Mar 2026 09:24:04 +0800 Subject: [PATCH 1/2] feat: add Qwen3.5-4B model support --- openclaw-combine/README.md | 16 +- .../run_qwen35_4b_openclaw_combine.sh | 218 ++++++++++ .../run_qwen35_4b_openclaw_combine_lora.sh | 201 +++++++++ openclaw-opd/README.md | 20 +- openclaw-opd/run_qwen35_4b_openclaw_opd.sh | 209 ++++++++++ .../run_qwen35_4b_openclaw_opd_topk.sh | 206 ++++++++++ .../run_qwen35_4b_openclaw_opd_topk_lora.sh | 195 +++++++++ openclaw-rl/README.md | 12 +- openclaw-rl/run_qwen35_4b_openclaw_rl.sh | 210 ++++++++++ openclaw-rl/run_qwen35_4b_openclaw_rl_lora.sh | 193 +++++++++ openclaw-test/README.md | 18 +- slime/README.md | 2 +- slime/scripts/models/qwen3.5-4B.sh | 25 ++ slime/slime/backends/fsdp_utils/actor.py | 2 +- .../megatron_utils/megatron_to_hf/__init__.py | 3 + .../megatron_utils/megatron_to_hf/qwen3_5.py | 136 +++++++ .../backends/megatron_utils/model_provider.py | 2 + slime/slime/backends/sglang_utils/qwen3_5.py | 382 ++++++++++++++++++ .../backends/sglang_utils/sglang_engine.py | 33 +- slime/slime_plugins/mbridge/__init__.py | 3 +- slime/slime_plugins/mbridge/qwen3_5.py | 57 +++ .../slime_plugins/megatron_bridge/__init__.py | 1 + .../slime_plugins/megatron_bridge/qwen3_5.py | 173 ++++++++ slime/slime_plugins/models/qwen3_5.py | 211 ++++++++++ slime/slime_plugins/sglang_models/__init__.py | 0 .../sglang_models/qwen3_5_causal.py | 9 + 26 files changed, 2509 insertions(+), 28 deletions(-) create mode 100644 openclaw-combine/run_qwen35_4b_openclaw_combine.sh create mode 100644 openclaw-combine/run_qwen35_4b_openclaw_combine_lora.sh create mode 100644 openclaw-opd/run_qwen35_4b_openclaw_opd.sh create mode 100644 openclaw-opd/run_qwen35_4b_openclaw_opd_topk.sh create mode 100644 openclaw-opd/run_qwen35_4b_openclaw_opd_topk_lora.sh create mode 100644 openclaw-rl/run_qwen35_4b_openclaw_rl.sh create mode 100644 openclaw-rl/run_qwen35_4b_openclaw_rl_lora.sh create mode 100644 slime/scripts/models/qwen3.5-4B.sh create mode 100644 slime/slime/backends/megatron_utils/megatron_to_hf/qwen3_5.py create mode 100644 slime/slime/backends/sglang_utils/qwen3_5.py create mode 100644 slime/slime_plugins/mbridge/qwen3_5.py create mode 100644 slime/slime_plugins/megatron_bridge/qwen3_5.py create mode 100644 slime/slime_plugins/models/qwen3_5.py create mode 100644 slime/slime_plugins/sglang_models/__init__.py create mode 100644 slime/slime_plugins/sglang_models/qwen3_5_causal.py diff --git a/openclaw-combine/README.md b/openclaw-combine/README.md index db8067b6f0..f852594144 100644 --- a/openclaw-combine/README.md +++ b/openclaw-combine/README.md @@ -40,7 +40,12 @@ For each main-line turn, after the next state arrives: ```bash cd slime +# Qwen3 bash ../openclaw-combine/run_qwen3_4b_openclaw_combine.sh +# Qwen3.5 +bash ../openclaw-combine/run_qwen35_4b_openclaw_combine.sh +# Qwen3.5 LoRA (4 GPUs) +bash ../openclaw-combine/run_qwen35_4b_openclaw_combine_lora.sh ``` ### Key Environment Variables @@ -58,9 +63,10 @@ All other variables (`NUM_GPUS`, `ACTOR_GPUS`, `HF_CKPT`, etc.) 
are shared with ```text openclaw-combine/ ├── README.md -├── run_qwen3_4b_openclaw_combine.sh # Launch script -├── openclaw_combine_api_server.py # Async proxy: hint judge + PRM eval + sample submission -├── openclaw_combine_rollout.py # Rollout bridge to SLIME trainer -├── combine_loss.py # Weighted advantage: w_rl * GRPO + w_opd * teacher -└── results/ # Runtime records (auto-created) +├── run_qwen3_4b_openclaw_combine.sh # Launch script (Qwen3) +├── run_qwen35_4b_openclaw_combine.sh # Launch script (Qwen3.5) +├── openclaw_combine_api_server.py # Async proxy: hint judge + PRM eval + sample submission +├── openclaw_combine_rollout.py # Rollout bridge to SLIME trainer +├── combine_loss.py # Weighted advantage: w_rl * GRPO + w_opd * teacher +└── results/ # Runtime records (auto-created) ``` diff --git a/openclaw-combine/run_qwen35_4b_openclaw_combine.sh b/openclaw-combine/run_qwen35_4b_openclaw_combine.sh new file mode 100644 index 0000000000..0faef59e36 --- /dev/null +++ b/openclaw-combine/run_qwen35_4b_openclaw_combine.sh @@ -0,0 +1,218 @@ +#!/bin/bash +# Qwen3.5-4B full fine-tuning with OpenClaw combined method (Binary RL + OPD) +# Uses Megatron backend with tensor parallelism + +SKIP_CLUSTER_CLEANUP=${SKIP_CLUSTER_CLEANUP:-0} +if [ "${SKIP_CLUSTER_CLEANUP}" != "1" ]; then + pkill -9 sglang + sleep 3 + ray stop --force + pkill -9 ray + pkill -9 python + sleep 3 + pkill -9 ray + pkill -9 python +fi + +set -ex + +export PYTHONUNBUFFERED=1 +export PYTHONFAULTHANDLER=1 +export FLASHINFER_WORKSPACE_BASE="${FLASHINFER_WORKSPACE_BASE:-/tmp}" + +NUM_GPUS=${NUM_GPUS:-8} +ACTOR_GPUS=${ACTOR_GPUS:-4} +ROLLOUT_GPUS=${ROLLOUT_GPUS:-2} +PRM_GPUS=${PRM_GPUS:-2} + +if (( ACTOR_GPUS + ROLLOUT_GPUS + PRM_GPUS > NUM_GPUS )); then + echo "ACTOR_GPUS + ROLLOUT_GPUS + PRM_GPUS must be <= NUM_GPUS" + echo "ACTOR_GPUS=${ACTOR_GPUS}, ROLLOUT_GPUS=${ROLLOUT_GPUS}, PRM_GPUS=${PRM_GPUS}, NUM_GPUS=${NUM_GPUS}" + exit 1 +fi + +export RAY_health_check_failure_threshold=20 +export RAY_health_check_period_ms=5000 +export RAY_health_check_timeout_ms=30000 +export RAY_num_heartbeats_timeout=60 + +SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" +SLIME_ROOT="$(cd -- "${SCRIPT_DIR}/../slime" &>/dev/null && pwd)" +REPO_ROOT="$(cd -- "${SCRIPT_DIR}/.." 
&>/dev/null && pwd)" +source "${SLIME_ROOT}/scripts/models/qwen3.5-4B.sh" + +HF_CKPT=${HF_CKPT:-${REPO_ROOT}/models/Qwen3.5-4B} +REF_LOAD=${REF_LOAD:-${HF_CKPT}} +SAVE_CKPT=${SAVE_CKPT:-${REPO_ROOT}/ckpt/qwen35-4b-openclaw-combine} +PRM_MODEL_PATH=${PRM_MODEL_PATH:-${HF_CKPT}} + +export SGLANG_API_KEY="${SGLANG_API_KEY}" +export SERVED_MODEL_NAME="qwen3.5-4b" +export HOST="0.0.0.0" +export PORT="30000" +export OPENCLAW_RECORD_ENABLED="${OPENCLAW_RECORD_ENABLED:-1}" +export OPENCLAW_RECORD_FILE="${SCRIPT_DIR}/results/qwen35_4b_record.jsonl" +export TP="2" +export CONTEXT_LENGTH="32768" +export MEM_FRACTION_STATIC="0.8" +export REASONING_PARSER="${REASONING_PARSER:-qwen3}" +export TOOL_CALL_PARSER="${TOOL_CALL_PARSER:-qwen25}" +export SGLANG_LANGUAGE_ONLY="${SGLANG_LANGUAGE_ONLY:-1}" +export PRM_M="${PRM_M:-1}" +export OPENCLAW_OPD_TEACHER_LP_MAX_CONCURRENCY="${OPENCLAW_OPD_TEACHER_LP_MAX_CONCURRENCY:-1}" +export OPENCLAW_COMBINE_W_RL="${OPENCLAW_COMBINE_W_RL:-1.0}" +export OPENCLAW_COMBINE_W_OPD="${OPENCLAW_COMBINE_W_OPD:-1.0}" + +CKPT_ARGS=( + --megatron-to-hf-mode bridge + --hf-checkpoint "${HF_CKPT}" + --ref-load "${REF_LOAD}" + --save "${SAVE_CKPT}" + --save-interval 100 + --rotary-base 5000000 +) + +ROLLOUT_ARGS=( + --disable-rollout-global-dataset + --rollout-function-path openclaw_combine_rollout.generate_rollout_openclaw_combine + + --num-rollout 100000000 + --rollout-batch-size 16 + --n-samples-per-prompt 1 + --rollout-max-response-len 8192 + --rollout-max-context-len 32768 + --rollout-temperature 0.6 + --reward-key score + + --num-steps-per-rollout 1 +) + +PERF_ARGS=( + --tensor-model-parallel-size 4 + --sequence-parallel + --pipeline-model-parallel-size 1 + --context-parallel-size 1 + --expert-model-parallel-size 1 + --expert-tensor-parallel-size 1 + + --recompute-granularity full + --recompute-method uniform + --recompute-num-layers 1 + + --use-dynamic-batch-size + --max-tokens-per-gpu 32768 + --log-probs-chunk-size 1024 +) + +COMBINE_ARGS=( + --advantage-estimator grpo + --disable-rewards-normalization + --loss-type custom_loss + --custom-loss-function-path combine_loss.combine_loss_function + --use-kl-loss + --kl-loss-coef 0.0 + --kl-loss-type low_var_kl + --entropy-coef 0.00 + --eps-clip 0.2 + --eps-clip-high 0.28 +) + +OPTIMIZER_ARGS=( + --optimizer adam + --lr 1e-5 + --lr-decay-style constant + --weight-decay 0.1 + --adam-beta1 0.9 + --adam-beta2 0.98 + --optimizer-cpu-offload + --overlap-cpu-optimizer-d2h-h2d + --use-precision-aware-optimizer +) + +EVAL_ARGS=() + +SGLANG_ARGS=( + --rollout-num-gpus-per-engine 2 + --sglang-tool-call-parser "${TOOL_CALL_PARSER}" + --sglang-mem-fraction-static 0.8 + --sglang-context-length 32768 + --sglang-reasoning-parser "${REASONING_PARSER}" +) + +if [ "${SGLANG_LANGUAGE_ONLY}" = "1" ]; then + SGLANG_ARGS+=(--sglang-language-only) +fi + +PRM_ARGS=( + --prm-enable + --prm-num-gpus "${PRM_GPUS}" + --prm-num-gpus-per-engine 2 + --prm-model-path "${PRM_MODEL_PATH}" + --prm-m "${PRM_M}" + --prm-temperature "${PRM_TEMPERATURE:-0.6}" + --prm-max-new-tokens "${PRM_MAX_NEW_TOKENS:-8192}" +) + +CUSTOM_ARGS=( + --custom-generate-function-path openclaw_combine_api_server.generate + --custom-rm-path openclaw_combine_api_server.reward_func +) + +MISC_ARGS=( + --attention-dropout 0.0 + --hidden-dropout 0.0 + --accumulate-allreduce-grads-in-fp32 + --attention-softmax-in-fp32 + --attention-backend flash +) + +USE_WANDB=${USE_WANDB:-1} +WANDB_PROJECT=${WANDB_PROJECT:-openclaw_rl} +WANDB_KEY_VALUE=${WANDB_KEY:-${WANDB_API_KEY:-}} +if [ "${USE_WANDB}" = "1" ] 
&& [ -n "${WANDB_KEY_VALUE}" ]; then + WANDB_ARGS=( + --use-wandb + --wandb-project ${WANDB_PROJECT} + --wandb-group qwen35-4b-openclaw-combine + --wandb-key ${WANDB_KEY_VALUE} + ) +else + WANDB_ARGS=() +fi + +export OPENCLAW_EVAL_MODE="${OPENCLAW_EVAL_MODE:-1}" + +export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} +export no_proxy="127.0.0.1,${MASTER_ADDR}" +ray start --head --node-ip-address "${MASTER_ADDR}" --num-gpus "${NUM_GPUS}" --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265 + +RUNTIME_ENV_JSON="{ + \"env_vars\": { + \"PYTHONPATH\": \"${REPO_ROOT}/Megatron-LM/:${SCRIPT_DIR}:${SCRIPT_DIR}/../openclaw-opd:${SLIME_ROOT}\", + \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\", + \"FLASHINFER_WORKSPACE_BASE\": \"${FLASHINFER_WORKSPACE_BASE}\", + \"OPENCLAW_EVAL_MODE\": \"${OPENCLAW_EVAL_MODE}\", + \"OPENCLAW_COMBINE_W_RL\": \"${OPENCLAW_COMBINE_W_RL}\", + \"OPENCLAW_COMBINE_W_OPD\": \"${OPENCLAW_COMBINE_W_OPD}\" + } +}" + +ray job submit --address="http://127.0.0.1:8265" \ + --runtime-env-json="${RUNTIME_ENV_JSON}" \ + -- python3 "${SLIME_ROOT}/train_async.py" \ + --actor-num-nodes 1 \ + --actor-num-gpus-per-node "${ACTOR_GPUS}" \ + --rollout-num-gpus "${ROLLOUT_GPUS}" \ + --num-gpus-per-node "${NUM_GPUS}" \ + ${MODEL_ARGS[@]} \ + ${CKPT_ARGS[@]} \ + ${ROLLOUT_ARGS[@]} \ + ${OPTIMIZER_ARGS[@]} \ + ${COMBINE_ARGS[@]} \ + ${PERF_ARGS[@]} \ + ${EVAL_ARGS[@]} \ + ${SGLANG_ARGS[@]} \ + ${MISC_ARGS[@]} \ + ${WANDB_ARGS[@]} \ + ${CUSTOM_ARGS[@]} \ + ${PRM_ARGS[@]} diff --git a/openclaw-combine/run_qwen35_4b_openclaw_combine_lora.sh b/openclaw-combine/run_qwen35_4b_openclaw_combine_lora.sh new file mode 100644 index 0000000000..9e3cfa4e51 --- /dev/null +++ b/openclaw-combine/run_qwen35_4b_openclaw_combine_lora.sh @@ -0,0 +1,201 @@ +#!/bin/bash +# Qwen3.5-4B LoRA training with OpenClaw combined method (Binary RL + OPD) +# Fewer GPUs than full fine-tuning: default 4 GPUs (2 actor + 1 rollout + 1 PRM) +# Uses FSDP backend instead of Megatron TP + +SKIP_CLUSTER_CLEANUP=${SKIP_CLUSTER_CLEANUP:-0} +if [ "${SKIP_CLUSTER_CLEANUP}" != "1" ]; then + pkill -9 sglang + sleep 3 + ray stop --force + pkill -9 ray + pkill -9 python + sleep 3 + pkill -9 ray + pkill -9 python +fi + +set -ex + +export PYTHONUNBUFFERED=1 +export PYTHONFAULTHANDLER=1 +export FLASHINFER_WORKSPACE_BASE="${FLASHINFER_WORKSPACE_BASE:-/tmp}" + +NUM_GPUS=${NUM_GPUS:-4} +ACTOR_GPUS=${ACTOR_GPUS:-2} +ROLLOUT_GPUS=${ROLLOUT_GPUS:-1} +PRM_GPUS=${PRM_GPUS:-1} + +if (( ACTOR_GPUS + ROLLOUT_GPUS + PRM_GPUS > NUM_GPUS )); then + echo "ACTOR_GPUS + ROLLOUT_GPUS + PRM_GPUS must be <= NUM_GPUS" + echo "ACTOR_GPUS=${ACTOR_GPUS}, ROLLOUT_GPUS=${ROLLOUT_GPUS}, PRM_GPUS=${PRM_GPUS}, NUM_GPUS=${NUM_GPUS}" + exit 1 +fi + +export RAY_health_check_failure_threshold=20 +export RAY_health_check_period_ms=5000 +export RAY_health_check_timeout_ms=30000 +export RAY_num_heartbeats_timeout=60 + +SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" +REPO_ROOT="$(cd -- "${SCRIPT_DIR}/.." 
&>/dev/null && pwd)" +SLIME_ROOT="$(cd -- "${SCRIPT_DIR}/../slime" &>/dev/null && pwd)" + +HF_CKPT=${HF_CKPT:-${REPO_ROOT}/models/Qwen3.5-4B} +REF_LOAD=${REF_LOAD:-${HF_CKPT}} +SAVE_CKPT=${SAVE_CKPT:-${REPO_ROOT}/ckpt/qwen35-4b-openclaw-combine-lora} +PRM_MODEL_PATH=${PRM_MODEL_PATH:-${HF_CKPT}} + +export SGLANG_API_KEY="${SGLANG_API_KEY}" +export SERVED_MODEL_NAME="qwen3.5-4b-lora" +export HOST="0.0.0.0" +export PORT="30000" +export OPENCLAW_RECORD_ENABLED="${OPENCLAW_RECORD_ENABLED:-1}" +export OPENCLAW_RECORD_FILE="${SCRIPT_DIR}/results/qwen35_4b_lora_record.jsonl" +export TP="${TP:-1}" +export CONTEXT_LENGTH="32768" +export MEM_FRACTION_STATIC="0.85" +export REASONING_PARSER="${REASONING_PARSER:-qwen3}" +export TOOL_CALL_PARSER="${TOOL_CALL_PARSER:-qwen25}" +export SGLANG_LANGUAGE_ONLY="${SGLANG_LANGUAGE_ONLY:-1}" +export PRM_M="${PRM_M:-1}" +export OPENCLAW_OPD_TEACHER_LP_MAX_CONCURRENCY="${OPENCLAW_OPD_TEACHER_LP_MAX_CONCURRENCY:-1}" +export OPENCLAW_COMBINE_W_RL="${OPENCLAW_COMBINE_W_RL:-1.0}" +export OPENCLAW_COMBINE_W_OPD="${OPENCLAW_COMBINE_W_OPD:-1.0}" + +CKPT_ARGS=( + --hf-checkpoint "${HF_CKPT}" + --ref-load "${REF_LOAD}" + --save "${SAVE_CKPT}" + --save-interval 1 +) + +ROLLOUT_ARGS=( + --disable-rollout-global-dataset + --rollout-function-path openclaw_combine_rollout.generate_rollout_openclaw_combine + + --num-rollout 100000000 + --rollout-batch-size 16 + --n-samples-per-prompt 1 + --rollout-max-response-len 8192 + --rollout-max-context-len 32768 + --rollout-temperature 0.6 + --reward-key score + + --num-steps-per-rollout 1 +) + +PERF_ARGS=( + --use-dynamic-batch-size + --max-tokens-per-gpu "${MAX_TOKENS_PER_GPU:-8192}" + --gradient-checkpointing +) + +COMBINE_ARGS=( + --advantage-estimator grpo + --disable-rewards-normalization + --loss-type custom_loss + --custom-loss-function-path combine_loss.combine_loss_function + --use-kl-loss + --kl-loss-coef 0.0 + --kl-loss-type low_var_kl + --entropy-coef 0.00 + --eps-clip 0.2 + --eps-clip-high 0.28 +) + +OPTIMIZER_ARGS=( + --optimizer adam + --lr 1e-5 + --lr-decay-style constant + --weight-decay 0.1 + --adam-beta1 0.9 + --adam-beta2 0.98 +) + +LORA_ARGS=( + --use-lora + --lora-rank 16 + --lora-alpha 32 + --lora-target-modules "q_proj,k_proj,v_proj,o_proj,gate_proj,up_proj,down_proj" +) + +EVAL_ARGS=() + +SGLANG_ARGS=( + --rollout-num-gpus-per-engine "${TP}" + --sglang-tool-call-parser "${TOOL_CALL_PARSER}" + --sglang-mem-fraction-static 0.85 + --sglang-context-length 32768 + --sglang-reasoning-parser "${REASONING_PARSER}" +) + +if [ "${SGLANG_LANGUAGE_ONLY}" = "1" ]; then + SGLANG_ARGS+=(--sglang-language-only) +fi + +PRM_ARGS=( + --prm-enable + --prm-num-gpus "${PRM_GPUS}" + --prm-num-gpus-per-engine "${PRM_TP:-${TP}}" + --prm-model-path "${PRM_MODEL_PATH}" + --prm-m "${PRM_M}" + --prm-temperature "${PRM_TEMPERATURE:-0.6}" + --prm-max-new-tokens "${PRM_MAX_NEW_TOKENS:-4096}" +) + +CUSTOM_ARGS=( + --custom-generate-function-path openclaw_combine_api_server.generate + --custom-rm-path openclaw_combine_api_server.reward_func +) + +USE_WANDB=${USE_WANDB:-1} +WANDB_PROJECT=${WANDB_PROJECT:-openclaw_rl} +WANDB_KEY_VALUE=${WANDB_KEY:-${WANDB_API_KEY:-}} +if [ "${USE_WANDB}" = "1" ] && [ -n "${WANDB_KEY_VALUE}" ]; then + WANDB_ARGS=( + --use-wandb + --wandb-project ${WANDB_PROJECT} + --wandb-group qwen35-4b-openclaw-combine-lora + --wandb-key ${WANDB_KEY_VALUE} + ) +else + WANDB_ARGS=() +fi + +export OPENCLAW_EVAL_MODE="${OPENCLAW_EVAL_MODE:-1}" + +export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} +export 
no_proxy="127.0.0.1,${MASTER_ADDR}" +ray start --head --node-ip-address "${MASTER_ADDR}" --num-gpus "${NUM_GPUS}" --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265 + +RUNTIME_ENV_JSON="{ + \"env_vars\": { + \"PYTHONPATH\": \"${SCRIPT_DIR}:${SCRIPT_DIR}/../openclaw-opd:${SLIME_ROOT}\", + \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\", + \"FLASHINFER_WORKSPACE_BASE\": \"${FLASHINFER_WORKSPACE_BASE}\", + \"OPENCLAW_EVAL_MODE\": \"${OPENCLAW_EVAL_MODE}\", + \"OPENCLAW_COMBINE_W_RL\": \"${OPENCLAW_COMBINE_W_RL}\", + \"OPENCLAW_COMBINE_W_OPD\": \"${OPENCLAW_COMBINE_W_OPD}\" + } +}" + +ray job submit --address="http://127.0.0.1:8265" \ + --runtime-env-json="${RUNTIME_ENV_JSON}" \ + -- python3 "${SLIME_ROOT}/train_async.py" \ + --train-backend fsdp \ + --actor-num-nodes 1 \ + --actor-num-gpus-per-node "${ACTOR_GPUS}" \ + --rollout-num-gpus "${ROLLOUT_GPUS}" \ + --num-gpus-per-node "${NUM_GPUS}" \ + ${CKPT_ARGS[@]} \ + ${ROLLOUT_ARGS[@]} \ + ${OPTIMIZER_ARGS[@]} \ + ${COMBINE_ARGS[@]} \ + ${PERF_ARGS[@]} \ + ${EVAL_ARGS[@]} \ + ${SGLANG_ARGS[@]} \ + ${WANDB_ARGS[@]} \ + ${CUSTOM_ARGS[@]} \ + ${PRM_ARGS[@]} \ + ${LORA_ARGS[@]} diff --git a/openclaw-opd/README.md b/openclaw-opd/README.md index 54add4e7d4..9c97654528 100644 --- a/openclaw-opd/README.md +++ b/openclaw-opd/README.md @@ -29,7 +29,10 @@ Default script: ```bash cd slime +# Qwen3 bash ../openclaw-opd/run_qwen3_4b_openclaw_opd.sh +# Qwen3.5 +bash ../openclaw-opd/run_qwen35_4b_openclaw_opd.sh ``` ## Option B: Top-K Logits Distillation (SDFT/SDPO-style) @@ -60,7 +63,10 @@ Top-K is implemented as an additive extension: ```bash cd slime +# Qwen3 bash ../openclaw-opd/run_qwen3_4b_openclaw_opd_topk.sh +# Qwen3.5 +bash ../openclaw-opd/run_qwen35_4b_openclaw_opd_topk.sh ``` Equivalent key args: @@ -78,10 +84,12 @@ Equivalent key args: ```text openclaw-opd/ ├── README.md -├── run_qwen3_4b_openclaw_opd.sh # Token-level OPD (default) -├── run_qwen3_4b_openclaw_opd_topk.sh # Top-K custom-loss path -├── topk_distillation_loss.py # Reverse-KL top-K loss (external custom loss) -├── openclaw_opd_api_server.py # Async judge + teacher query + sample submission -├── openclaw_opd_rollout.py # Rollout bridge to SLIME trainer -└── results/ # Runtime records (auto-created) +├── run_qwen3_4b_openclaw_opd.sh # Token-level OPD (Qwen3) +├── run_qwen35_4b_openclaw_opd.sh # Token-level OPD (Qwen3.5) +├── run_qwen3_4b_openclaw_opd_topk.sh # Top-K (Qwen3) +├── run_qwen35_4b_openclaw_opd_topk.sh # Top-K (Qwen3.5) +├── topk_distillation_loss.py # Reverse-KL top-K loss (external custom loss) +├── openclaw_opd_api_server.py # Async judge + teacher query + sample submission +├── openclaw_opd_rollout.py # Rollout bridge to SLIME trainer +└── results/ # Runtime records (auto-created) ``` diff --git a/openclaw-opd/run_qwen35_4b_openclaw_opd.sh b/openclaw-opd/run_qwen35_4b_openclaw_opd.sh new file mode 100644 index 0000000000..399f3708dd --- /dev/null +++ b/openclaw-opd/run_qwen35_4b_openclaw_opd.sh @@ -0,0 +1,209 @@ +#!/bin/bash +# Qwen3.5-4B full fine-tuning with OpenClaw OPD (on-policy distillation) +# Uses Megatron backend with tensor parallelism + +SKIP_CLUSTER_CLEANUP=${SKIP_CLUSTER_CLEANUP:-0} +if [ "${SKIP_CLUSTER_CLEANUP}" != "1" ]; then + pkill -9 sglang + sleep 3 + ray stop --force + pkill -9 ray + pkill -9 python + sleep 3 + pkill -9 ray + pkill -9 python +fi + +set -ex + +export PYTHONUNBUFFERED=1 +export PYTHONFAULTHANDLER=1 +export FLASHINFER_WORKSPACE_BASE="${FLASHINFER_WORKSPACE_BASE:-/tmp}" + +NUM_GPUS=${NUM_GPUS:-8} +ACTOR_GPUS=${ACTOR_GPUS:-4} 
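+# GPU budget: the actor (training), rollout (sglang), and PRM pools are disjoint;
+# the check below enforces that they fit within NUM_GPUS.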
+ROLLOUT_GPUS=${ROLLOUT_GPUS:-2} +PRM_GPUS=${PRM_GPUS:-2} + +if (( ACTOR_GPUS + ROLLOUT_GPUS + PRM_GPUS > NUM_GPUS )); then + echo "ACTOR_GPUS + ROLLOUT_GPUS + PRM_GPUS must be <= NUM_GPUS" + echo "ACTOR_GPUS=${ACTOR_GPUS}, ROLLOUT_GPUS=${ROLLOUT_GPUS}, PRM_GPUS=${PRM_GPUS}, NUM_GPUS=${NUM_GPUS}" + exit 1 +fi + +export RAY_health_check_failure_threshold=20 +export RAY_health_check_period_ms=5000 +export RAY_health_check_timeout_ms=30000 +export RAY_num_heartbeats_timeout=60 + +SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" +SLIME_ROOT="$(cd -- "${SCRIPT_DIR}/../slime" &>/dev/null && pwd)" +REPO_ROOT="$(cd -- "${SCRIPT_DIR}/.." &>/dev/null && pwd)" +source "${SLIME_ROOT}/scripts/models/qwen3.5-4B.sh" + +HF_CKPT=${HF_CKPT:-${REPO_ROOT}/models/Qwen3.5-4B} +REF_LOAD=${REF_LOAD:-${HF_CKPT}} +SAVE_CKPT=${SAVE_CKPT:-${REPO_ROOT}/ckpt/qwen35-4b-openclaw-opd} +PRM_MODEL_PATH=${PRM_MODEL_PATH:-${HF_CKPT}} + +export SGLANG_API_KEY="${SGLANG_API_KEY}" +export SERVED_MODEL_NAME="qwen3.5-4b" +export HOST="0.0.0.0" +export PORT="30000" +export OPENCLAW_RECORD_ENABLED="${OPENCLAW_RECORD_ENABLED:-1}" +export OPENCLAW_RECORD_FILE="${SCRIPT_DIR}/results/qwen35_4b_record.jsonl" +export TP="2" +export CONTEXT_LENGTH="32768" +export MEM_FRACTION_STATIC="0.8" +export REASONING_PARSER="${REASONING_PARSER:-qwen3}" +export TOOL_CALL_PARSER="${TOOL_CALL_PARSER:-qwen25}" +export SGLANG_LANGUAGE_ONLY="${SGLANG_LANGUAGE_ONLY:-1}" +export PRM_M="${PRM_M:-1}" +export OPENCLAW_OPD_TEACHER_LP_MAX_CONCURRENCY="${OPENCLAW_OPD_TEACHER_LP_MAX_CONCURRENCY:-1}" + +CKPT_ARGS=( + --megatron-to-hf-mode bridge + --hf-checkpoint "${HF_CKPT}" + --ref-load "${REF_LOAD}" + --save "${SAVE_CKPT}" + --save-interval 100 + --rotary-base 5000000 +) + +ROLLOUT_ARGS=( + --disable-rollout-global-dataset + --rollout-function-path openclaw_opd_rollout.generate_rollout_openclaw_opd + + --num-rollout 100000000 + --rollout-batch-size 16 + --n-samples-per-prompt 1 + --rollout-max-response-len 8192 + --rollout-max-context-len 32768 + --rollout-temperature 0.6 + --reward-key score + + --num-steps-per-rollout 1 +) + +PERF_ARGS=( + --tensor-model-parallel-size 4 + --sequence-parallel + --pipeline-model-parallel-size 1 + --context-parallel-size 1 + --expert-model-parallel-size 1 + --expert-tensor-parallel-size 1 + + --recompute-granularity full + --recompute-method uniform + --recompute-num-layers 1 + + --use-dynamic-batch-size + --max-tokens-per-gpu 32768 + --log-probs-chunk-size 1024 +) + +OPD_ARGS=( + --advantage-estimator on_policy_distillation + --use-kl-loss + --kl-loss-coef 0.0 + --kl-loss-type low_var_kl + --entropy-coef 0.00 +) + +OPTIMIZER_ARGS=( + --optimizer adam + --lr 1e-5 + --lr-decay-style constant + --weight-decay 0.1 + --adam-beta1 0.9 + --adam-beta2 0.98 + --optimizer-cpu-offload + --overlap-cpu-optimizer-d2h-h2d + --use-precision-aware-optimizer +) + +EVAL_ARGS=() + +SGLANG_ARGS=( + --rollout-num-gpus-per-engine 2 + --sglang-tool-call-parser "${TOOL_CALL_PARSER}" + --sglang-mem-fraction-static 0.8 + --sglang-context-length 32768 + --sglang-reasoning-parser "${REASONING_PARSER}" +) + +if [ "${SGLANG_LANGUAGE_ONLY}" = "1" ]; then + SGLANG_ARGS+=(--sglang-language-only) +fi + +PRM_ARGS=( + --prm-enable + --prm-num-gpus "${PRM_GPUS}" + --prm-num-gpus-per-engine 2 + --prm-model-path "${PRM_MODEL_PATH}" + --prm-m "${PRM_M}" + --prm-temperature "${PRM_TEMPERATURE:-0.6}" + --prm-max-new-tokens "${PRM_MAX_NEW_TOKENS:-8192}" +) + +CUSTOM_ARGS=( + --custom-generate-function-path openclaw_opd_api_server.generate + 
--custom-rm-path openclaw_opd_api_server.reward_func +) + +MISC_ARGS=( + --attention-dropout 0.0 + --hidden-dropout 0.0 + --accumulate-allreduce-grads-in-fp32 + --attention-softmax-in-fp32 + --attention-backend flash +) + +USE_WANDB=${USE_WANDB:-1} +WANDB_PROJECT=${WANDB_PROJECT:-openclaw_rl} +WANDB_KEY_VALUE=${WANDB_KEY:-${WANDB_API_KEY:-}} +if [ "${USE_WANDB}" = "1" ] && [ -n "${WANDB_KEY_VALUE}" ]; then + WANDB_ARGS=( + --use-wandb + --wandb-project ${WANDB_PROJECT} + --wandb-group qwen35-4b-openclaw-opd + --wandb-key ${WANDB_KEY_VALUE} + ) +else + WANDB_ARGS=() +fi + +export OPENCLAW_EVAL_MODE="${OPENCLAW_EVAL_MODE:-1}" + +export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} +export no_proxy="127.0.0.1,${MASTER_ADDR}" +ray start --head --node-ip-address "${MASTER_ADDR}" --num-gpus "${NUM_GPUS}" --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265 + +RUNTIME_ENV_JSON="{ + \"env_vars\": { + \"PYTHONPATH\": \"${REPO_ROOT}/Megatron-LM/:${SCRIPT_DIR}:${SLIME_ROOT}\", + \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\", + \"FLASHINFER_WORKSPACE_BASE\": \"${FLASHINFER_WORKSPACE_BASE}\", + \"OPENCLAW_EVAL_MODE\": \"${OPENCLAW_EVAL_MODE}\" + } +}" + +ray job submit --address="http://127.0.0.1:8265" \ + --runtime-env-json="${RUNTIME_ENV_JSON}" \ + -- python3 "${SLIME_ROOT}/train_async.py" \ + --actor-num-nodes 1 \ + --actor-num-gpus-per-node "${ACTOR_GPUS}" \ + --rollout-num-gpus "${ROLLOUT_GPUS}" \ + --num-gpus-per-node "${NUM_GPUS}" \ + ${MODEL_ARGS[@]} \ + ${CKPT_ARGS[@]} \ + ${ROLLOUT_ARGS[@]} \ + ${OPTIMIZER_ARGS[@]} \ + ${OPD_ARGS[@]} \ + ${PERF_ARGS[@]} \ + ${EVAL_ARGS[@]} \ + ${SGLANG_ARGS[@]} \ + ${MISC_ARGS[@]} \ + ${WANDB_ARGS[@]} \ + ${CUSTOM_ARGS[@]} \ + ${PRM_ARGS[@]} diff --git a/openclaw-opd/run_qwen35_4b_openclaw_opd_topk.sh b/openclaw-opd/run_qwen35_4b_openclaw_opd_topk.sh new file mode 100644 index 0000000000..8a11332bab --- /dev/null +++ b/openclaw-opd/run_qwen35_4b_openclaw_opd_topk.sh @@ -0,0 +1,206 @@ +#!/bin/bash +# Qwen3.5-4B full fine-tuning with OpenClaw OPD top-k distillation +# Uses Megatron backend with tensor parallelism + +SKIP_CLUSTER_CLEANUP=${SKIP_CLUSTER_CLEANUP:-0} +if [ "${SKIP_CLUSTER_CLEANUP}" != "1" ]; then + pkill -9 sglang + sleep 3 + ray stop --force + pkill -9 ray + pkill -9 python + sleep 3 + pkill -9 ray + pkill -9 python +fi + +set -ex + +export PYTHONUNBUFFERED=1 +export PYTHONFAULTHANDLER=1 +export FLASHINFER_WORKSPACE_BASE="${FLASHINFER_WORKSPACE_BASE:-/tmp}" + +NUM_GPUS=${NUM_GPUS:-8} +ACTOR_GPUS=${ACTOR_GPUS:-4} +ROLLOUT_GPUS=${ROLLOUT_GPUS:-2} +PRM_GPUS=${PRM_GPUS:-2} + +if (( ACTOR_GPUS + ROLLOUT_GPUS + PRM_GPUS > NUM_GPUS )); then + echo "ACTOR_GPUS + ROLLOUT_GPUS + PRM_GPUS must be <= NUM_GPUS" + echo "ACTOR_GPUS=${ACTOR_GPUS}, ROLLOUT_GPUS=${ROLLOUT_GPUS}, PRM_GPUS=${PRM_GPUS}, NUM_GPUS=${NUM_GPUS}" + exit 1 +fi + +export RAY_health_check_failure_threshold=20 +export RAY_health_check_period_ms=5000 +export RAY_health_check_timeout_ms=30000 +export RAY_num_heartbeats_timeout=60 + +SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" +SLIME_ROOT="$(cd -- "${SCRIPT_DIR}/../slime" &>/dev/null && pwd)" +REPO_ROOT="$(cd -- "${SCRIPT_DIR}/.." 
&>/dev/null && pwd)" +source "${SLIME_ROOT}/scripts/models/qwen3.5-4B.sh" + +HF_CKPT=${HF_CKPT:-${REPO_ROOT}/models/Qwen3.5-4B} +REF_LOAD=${REF_LOAD:-${HF_CKPT}} +SAVE_CKPT=${SAVE_CKPT:-${REPO_ROOT}/ckpt/qwen35-4b-openclaw-opd-topk} +PRM_MODEL_PATH=${PRM_MODEL_PATH:-${HF_CKPT}} + +export SGLANG_API_KEY="${SGLANG_API_KEY}" +export SERVED_MODEL_NAME="qwen3.5-4b" +export HOST="0.0.0.0" +export PORT="30000" +export OPENCLAW_RECORD_ENABLED="${OPENCLAW_RECORD_ENABLED:-1}" +export OPENCLAW_RECORD_FILE="${SCRIPT_DIR}/results/qwen35_4b_record.jsonl" +export TP="2" +export CONTEXT_LENGTH="32768" +export MEM_FRACTION_STATIC="0.85" +export REASONING_PARSER="${REASONING_PARSER:-qwen3}" +export TOOL_CALL_PARSER="${TOOL_CALL_PARSER:-qwen25}" +export SGLANG_LANGUAGE_ONLY="${SGLANG_LANGUAGE_ONLY:-1}" +export PRM_M="${PRM_M:-3}" +export OPENCLAW_OPD_TEACHER_LP_MAX_CONCURRENCY="${OPENCLAW_OPD_TEACHER_LP_MAX_CONCURRENCY:-3}" + +CKPT_ARGS=( + --megatron-to-hf-mode bridge + --hf-checkpoint "${HF_CKPT}" + --ref-load "${REF_LOAD}" + --save "${SAVE_CKPT}" + --save-interval 1 + --rotary-base 5000000 +) + +ROLLOUT_ARGS=( + --disable-rollout-global-dataset + --rollout-function-path openclaw_opd_rollout.generate_rollout_openclaw_opd + + --num-rollout 100000000 + --rollout-batch-size 4 + --n-samples-per-prompt 1 + --rollout-max-response-len 8192 + --rollout-max-context-len 32768 + --rollout-temperature 0.6 + --reward-key score + + --num-steps-per-rollout 1 +) + +PERF_ARGS=( + --tensor-model-parallel-size 4 + --sequence-parallel + --pipeline-model-parallel-size 1 + --context-parallel-size 1 + --expert-model-parallel-size 1 + --expert-tensor-parallel-size 1 + + --recompute-granularity full + --recompute-method uniform + --recompute-num-layers 1 + + --use-dynamic-batch-size + --max-tokens-per-gpu 32768 + --log-probs-chunk-size 1024 +) + +OPD_ARGS=( + --loss-type custom_loss + --custom-loss-function-path topk_distillation_loss.topk_distillation_loss_function + --distill-topk 50 + --disable-compute-advantages-and-returns + --entropy-coef 0.00 +) + +OPTIMIZER_ARGS=( + --optimizer adam + --lr 1e-6 + --lr-decay-style constant + --weight-decay 0.1 + --adam-beta1 0.9 + --adam-beta2 0.98 + --optimizer-cpu-offload + --overlap-cpu-optimizer-d2h-h2d + --use-precision-aware-optimizer +) + +EVAL_ARGS=() + +SGLANG_ARGS=( + --rollout-num-gpus-per-engine 2 + --sglang-tool-call-parser "${TOOL_CALL_PARSER}" + --sglang-mem-fraction-static 0.85 + --sglang-context-length 32768 + --sglang-reasoning-parser "${REASONING_PARSER}" +) + +if [ "${SGLANG_LANGUAGE_ONLY}" = "1" ]; then + SGLANG_ARGS+=(--sglang-language-only) +fi + +PRM_ARGS=( + --prm-enable + --prm-num-gpus "${PRM_GPUS}" + --prm-num-gpus-per-engine 2 + --prm-model-path "${PRM_MODEL_PATH}" + --prm-m "${PRM_M}" + --prm-temperature "${PRM_TEMPERATURE:-0.6}" + --prm-max-new-tokens "${PRM_MAX_NEW_TOKENS:-8192}" +) + +CUSTOM_ARGS=( + --custom-generate-function-path openclaw_opd_api_server.generate + --custom-rm-path openclaw_opd_api_server.reward_func +) + +MISC_ARGS=( + --attention-dropout 0.0 + --hidden-dropout 0.0 + --accumulate-allreduce-grads-in-fp32 + --attention-softmax-in-fp32 + --attention-backend flash +) + +USE_WANDB=${USE_WANDB:-1} +WANDB_PROJECT=${WANDB_PROJECT:-openclaw_rl} +WANDB_KEY_VALUE=${WANDB_KEY:-${WANDB_API_KEY:-}} +if [ "${USE_WANDB}" = "1" ] && [ -n "${WANDB_KEY_VALUE}" ]; then + WANDB_ARGS=( + --use-wandb + --wandb-project ${WANDB_PROJECT} + --wandb-group qwen35-4b-openclaw-opd-topk + --wandb-key ${WANDB_KEY_VALUE} + ) +else + WANDB_ARGS=() +fi + +export 
MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} +export no_proxy="127.0.0.1,${MASTER_ADDR}" +ray start --head --node-ip-address "${MASTER_ADDR}" --num-gpus "${NUM_GPUS}" --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265 + +RUNTIME_ENV_JSON="{ + \"env_vars\": { + \"PYTHONPATH\": \"${REPO_ROOT}/Megatron-LM/:${SCRIPT_DIR}:${SLIME_ROOT}\", + \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\", + \"FLASHINFER_WORKSPACE_BASE\": \"${FLASHINFER_WORKSPACE_BASE}\" + } +}" + +ray job submit --address="http://127.0.0.1:8265" \ + --runtime-env-json="${RUNTIME_ENV_JSON}" \ + -- python3 "${SLIME_ROOT}/train_async.py" \ + --actor-num-nodes 1 \ + --actor-num-gpus-per-node "${ACTOR_GPUS}" \ + --rollout-num-gpus "${ROLLOUT_GPUS}" \ + --num-gpus-per-node "${NUM_GPUS}" \ + ${MODEL_ARGS[@]} \ + ${CKPT_ARGS[@]} \ + ${ROLLOUT_ARGS[@]} \ + ${OPTIMIZER_ARGS[@]} \ + ${OPD_ARGS[@]} \ + ${PERF_ARGS[@]} \ + ${EVAL_ARGS[@]} \ + ${SGLANG_ARGS[@]} \ + ${MISC_ARGS[@]} \ + ${WANDB_ARGS[@]} \ + ${CUSTOM_ARGS[@]} \ + ${PRM_ARGS[@]} diff --git a/openclaw-opd/run_qwen35_4b_openclaw_opd_topk_lora.sh b/openclaw-opd/run_qwen35_4b_openclaw_opd_topk_lora.sh new file mode 100644 index 0000000000..085e249e70 --- /dev/null +++ b/openclaw-opd/run_qwen35_4b_openclaw_opd_topk_lora.sh @@ -0,0 +1,195 @@ +#!/bin/bash +# Qwen3.5-4B LoRA training with OpenClaw OPD top-k distillation +# Fewer GPUs than full fine-tuning: default 4 GPUs (2 actor + 1 rollout + 1 PRM) +# Uses FSDP backend instead of Megatron TP + +SKIP_CLUSTER_CLEANUP=${SKIP_CLUSTER_CLEANUP:-0} +if [ "${SKIP_CLUSTER_CLEANUP}" != "1" ]; then + pkill -9 sglang + sleep 3 + ray stop --force + pkill -9 ray + pkill -9 python + sleep 3 + pkill -9 ray + pkill -9 python +fi + +set -ex + +# keep stdout/stderr unbuffered in ray jobs +export PYTHONUNBUFFERED=1 +export PYTHONFAULTHANDLER=1 +export FLASHINFER_WORKSPACE_BASE="${FLASHINFER_WORKSPACE_BASE:-/tmp}" + +NUM_GPUS=${NUM_GPUS:-4} +ACTOR_GPUS=${ACTOR_GPUS:-2} +ROLLOUT_GPUS=${ROLLOUT_GPUS:-1} +PRM_GPUS=${PRM_GPUS:-1} + +if (( ACTOR_GPUS + ROLLOUT_GPUS + PRM_GPUS > NUM_GPUS )); then + echo "ACTOR_GPUS + ROLLOUT_GPUS + PRM_GPUS must be <= NUM_GPUS" + echo "ACTOR_GPUS=${ACTOR_GPUS}, ROLLOUT_GPUS=${ROLLOUT_GPUS}, PRM_GPUS=${PRM_GPUS}, NUM_GPUS=${NUM_GPUS}" + exit 1 +fi + +export RAY_health_check_failure_threshold=20 +export RAY_health_check_period_ms=5000 +export RAY_health_check_timeout_ms=30000 +export RAY_num_heartbeats_timeout=60 + +SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" +SLIME_ROOT="$(cd -- "${SCRIPT_DIR}/../slime" &>/dev/null && pwd)" +REPO_ROOT="$(cd -- "${SCRIPT_DIR}/.." 
&>/dev/null && pwd)" + +HF_CKPT=${HF_CKPT:-${REPO_ROOT}/models/Qwen3.5-4B} +REF_LOAD=${REF_LOAD:-${HF_CKPT}} +SAVE_CKPT=${SAVE_CKPT:-${REPO_ROOT}/ckpt/qwen35-4b-openclaw-opd-topk-lora} +PRM_MODEL_PATH=${PRM_MODEL_PATH:-${HF_CKPT}} + +export SGLANG_API_KEY="${SGLANG_API_KEY}" +export SERVED_MODEL_NAME="qwen3.5-4b-lora" +export HOST="0.0.0.0" +export PORT="30000" +export OPENCLAW_RECORD_ENABLED="${OPENCLAW_RECORD_ENABLED:-1}" +export OPENCLAW_RECORD_FILE="${SCRIPT_DIR}/results/qwen35_4b_lora_record.jsonl" +export TP="${TP:-1}" +export CONTEXT_LENGTH="32768" +export MEM_FRACTION_STATIC="0.85" +export REASONING_PARSER="${REASONING_PARSER:-qwen3}" +export TOOL_CALL_PARSER="${TOOL_CALL_PARSER:-qwen25}" +export SGLANG_LANGUAGE_ONLY="${SGLANG_LANGUAGE_ONLY:-1}" +export PRM_M="${PRM_M:-3}" +export OPENCLAW_OPD_TEACHER_LP_MAX_CONCURRENCY="${OPENCLAW_OPD_TEACHER_LP_MAX_CONCURRENCY:-3}" + + +CKPT_ARGS=( + --hf-checkpoint "${HF_CKPT}" + --ref-load "${REF_LOAD}" + --save "${SAVE_CKPT}" + --save-interval 1 +) + +ROLLOUT_ARGS=( + --disable-rollout-global-dataset + --rollout-function-path openclaw_opd_rollout.generate_rollout_openclaw_opd + + --num-rollout 100000000 + --rollout-batch-size 4 + --n-samples-per-prompt 1 + --rollout-max-response-len 8192 + --rollout-max-context-len 32768 + --rollout-temperature 0.6 + --reward-key score + + --num-steps-per-rollout 1 +) + +PERF_ARGS=( + --use-dynamic-batch-size + --max-tokens-per-gpu "${MAX_TOKENS_PER_GPU:-8192}" + --gradient-checkpointing +) + +OPD_ARGS=( + --loss-type custom_loss + --custom-loss-function-path topk_distillation_loss.topk_distillation_loss_function + --distill-topk 50 + --disable-compute-advantages-and-returns + # OPD rewards are dummy 1.0 values; without this flag, _drop_constant_reward_groups + # treats each single-sample group as "constant" and drops all but one, causing an + # empty-partition crash when dp_size > 1 (e.g. FSDP with ACTOR_GPUS >= 2). 
+ --disable-rewards-normalization + --entropy-coef 0.00 +) + +OPTIMIZER_ARGS=( + --optimizer adam + --lr 1e-6 + --lr-decay-style constant + --weight-decay 0.1 + --adam-beta1 0.9 + --adam-beta2 0.98 +) + +LORA_ARGS=( + --use-lora + --lora-rank 16 + --lora-alpha 32 + --lora-target-modules "q_proj,k_proj,v_proj,o_proj,gate_proj,up_proj,down_proj" +) + +EVAL_ARGS=() + +SGLANG_ARGS=( + --rollout-num-gpus-per-engine "${TP}" + --sglang-tool-call-parser "${TOOL_CALL_PARSER}" + --sglang-mem-fraction-static 0.85 + --sglang-context-length 32768 + --sglang-reasoning-parser "${REASONING_PARSER}" +) + +if [ "${SGLANG_LANGUAGE_ONLY}" = "1" ]; then + SGLANG_ARGS+=(--sglang-language-only) +fi + +PRM_ARGS=( + --prm-enable + --prm-num-gpus "${PRM_GPUS}" + --prm-num-gpus-per-engine "${PRM_TP:-${TP}}" + --prm-model-path "${PRM_MODEL_PATH}" + --prm-m "${PRM_M}" + --prm-temperature "${PRM_TEMPERATURE:-0.6}" + --prm-max-new-tokens "${PRM_MAX_NEW_TOKENS:-4096}" +) + +CUSTOM_ARGS=( + --custom-generate-function-path openclaw_opd_api_server.generate + --custom-rm-path openclaw_opd_api_server.reward_func +) + +USE_WANDB=${USE_WANDB:-1} +WANDB_PROJECT=${WANDB_PROJECT:-openclaw_rl} +WANDB_KEY_VALUE=${WANDB_KEY:-${WANDB_API_KEY:-}} +if [ "${USE_WANDB}" = "1" ] && [ -n "${WANDB_KEY_VALUE}" ]; then + WANDB_ARGS=( + --use-wandb + --wandb-project ${WANDB_PROJECT} + --wandb-group qwen35-4b-openclaw-opd-topk-lora + --wandb-key ${WANDB_KEY_VALUE} + ) +else + WANDB_ARGS=() +fi + +export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} +export no_proxy="127.0.0.1,${MASTER_ADDR}" +ray start --head --node-ip-address "${MASTER_ADDR}" --num-gpus "${NUM_GPUS}" --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265 + +RUNTIME_ENV_JSON="{ + \"env_vars\": { + \"PYTHONPATH\": \"${SCRIPT_DIR}:${SLIME_ROOT}\", + \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\", + \"FLASHINFER_WORKSPACE_BASE\": \"${FLASHINFER_WORKSPACE_BASE}\" + } +}" + +ray job submit --address="http://127.0.0.1:8265" \ + --runtime-env-json="${RUNTIME_ENV_JSON}" \ + -- python3 "${SLIME_ROOT}/train_async.py" \ + --train-backend fsdp \ + --actor-num-nodes 1 \ + --actor-num-gpus-per-node "${ACTOR_GPUS}" \ + --rollout-num-gpus "${ROLLOUT_GPUS}" \ + --num-gpus-per-node "${NUM_GPUS}" \ + ${CKPT_ARGS[@]} \ + ${ROLLOUT_ARGS[@]} \ + ${OPTIMIZER_ARGS[@]} \ + ${OPD_ARGS[@]} \ + ${PERF_ARGS[@]} \ + ${EVAL_ARGS[@]} \ + ${SGLANG_ARGS[@]} \ + ${WANDB_ARGS[@]} \ + ${CUSTOM_ARGS[@]} \ + ${PRM_ARGS[@]} \ + ${LORA_ARGS[@]} diff --git a/openclaw-rl/README.md b/openclaw-rl/README.md index fa40fae9ea..0f37441f46 100644 --- a/openclaw-rl/README.md +++ b/openclaw-rl/README.md @@ -42,7 +42,10 @@ where $\beta_{\text{KL}} = 0.02$. 
Entropy bonus is disabled ($\beta_{\text{ent}} ```bash cd slime +# Qwen3 bash ../openclaw-rl/run_qwen3_4b_openclaw_rl.sh +# Qwen3.5 +bash ../openclaw-rl/run_qwen35_4b_openclaw_rl.sh ``` @@ -52,8 +55,9 @@ bash ../openclaw-rl/run_qwen3_4b_openclaw_rl.sh ``` openclaw-rl/ ├── README.md -├── run_qwen3_4b_openclaw_rl.sh # Launch script -├── openclaw_api_server.py # FastAPI proxy + PRM scoring + sample submission -├── openclaw_rollout.py # Async rollout worker (bridges API server ↔ SLIME trainer) -└── results/ # Runtime records (auto-created) +├── run_qwen3_4b_openclaw_rl.sh # Launch script (Qwen3) +├── run_qwen35_4b_openclaw_rl.sh # Launch script (Qwen3.5) +├── openclaw_api_server.py # FastAPI proxy + PRM scoring + sample submission +├── openclaw_rollout.py # Async rollout worker (bridges API server ↔ SLIME trainer) +└── results/ # Runtime records (auto-created) ``` diff --git a/openclaw-rl/run_qwen35_4b_openclaw_rl.sh b/openclaw-rl/run_qwen35_4b_openclaw_rl.sh new file mode 100644 index 0000000000..9f8639ee1c --- /dev/null +++ b/openclaw-rl/run_qwen35_4b_openclaw_rl.sh @@ -0,0 +1,210 @@ +#!/bin/bash +# Qwen3.5-4B full fine-tuning with OpenClaw RL (Binary RL / GRPO) +# Uses Megatron backend with tensor parallelism + +SKIP_CLUSTER_CLEANUP=${SKIP_CLUSTER_CLEANUP:-0} +if [ "${SKIP_CLUSTER_CLEANUP}" != "1" ]; then + pkill -9 sglang + sleep 3 + ray stop --force + pkill -9 ray + pkill -9 python + sleep 3 + pkill -9 ray + pkill -9 python +fi + +set -ex + +# keep stdout/stderr unbuffered in ray jobs +export PYTHONUNBUFFERED=1 +export PYTHONFAULTHANDLER=1 +export FLASHINFER_WORKSPACE_BASE="${FLASHINFER_WORKSPACE_BASE:-/tmp}" + +NUM_GPUS=${NUM_GPUS:-8} +ACTOR_GPUS=${ACTOR_GPUS:-4} +ROLLOUT_GPUS=${ROLLOUT_GPUS:-2} +PRM_GPUS=${PRM_GPUS:-2} + +if (( ACTOR_GPUS + ROLLOUT_GPUS + PRM_GPUS > NUM_GPUS )); then + echo "ACTOR_GPUS + ROLLOUT_GPUS + PRM_GPUS must be <= NUM_GPUS" + echo "ACTOR_GPUS=${ACTOR_GPUS}, ROLLOUT_GPUS=${ROLLOUT_GPUS}, PRM_GPUS=${PRM_GPUS}, NUM_GPUS=${NUM_GPUS}" + exit 1 +fi + +export RAY_health_check_failure_threshold=20 +export RAY_health_check_period_ms=5000 +export RAY_health_check_timeout_ms=30000 +export RAY_num_heartbeats_timeout=60 + +SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" +SLIME_ROOT="$(cd -- "${SCRIPT_DIR}/../slime" &>/dev/null && pwd)" +REPO_ROOT="$(cd -- "${SCRIPT_DIR}/.." 
&>/dev/null && pwd)" +source "${SLIME_ROOT}/scripts/models/qwen3.5-4B.sh" + +HF_CKPT=${HF_CKPT:-${REPO_ROOT}/models/Qwen3.5-4B} +REF_LOAD=${REF_LOAD:-${HF_CKPT}} +SAVE_CKPT=${SAVE_CKPT:-${REPO_ROOT}/ckpt/qwen35-4b-openclaw-rl} +PRM_MODEL_PATH=${PRM_MODEL_PATH:-${HF_CKPT}} + +export SGLANG_API_KEY="${SGLANG_API_KEY}" +export SERVED_MODEL_NAME="qwen3.5-4b" +export HOST="0.0.0.0" +export PORT="30000" +export OPENCLAW_RECORD_ENABLED="${OPENCLAW_RECORD_ENABLED:-1}" +export OPENCLAW_RECORD_FILE="${SCRIPT_DIR}/results/qwen35_4b_record.jsonl" +export TP="2" +export CONTEXT_LENGTH="32768" +export MEM_FRACTION_STATIC="0.85" +export REASONING_PARSER="${REASONING_PARSER:-qwen3}" +export TOOL_CALL_PARSER="${TOOL_CALL_PARSER:-qwen25}" +export SGLANG_LANGUAGE_ONLY="${SGLANG_LANGUAGE_ONLY:-1}" +export PRM_M="${PRM_M:-3}" + + +CKPT_ARGS=( + --megatron-to-hf-mode bridge + --hf-checkpoint "${HF_CKPT}" + --ref-load "${REF_LOAD}" + --save "${SAVE_CKPT}" + --save-interval 100 + --rotary-base 5000000 +) + +ROLLOUT_ARGS=( + --disable-rollout-global-dataset + --rollout-function-path openclaw_rollout.generate_rollout_openclaw + + --num-rollout 100000000 + --rollout-batch-size 16 + --n-samples-per-prompt 1 + --rollout-max-response-len 8192 + --rollout-max-context-len 32768 + --rollout-temperature 0.6 + --reward-key score + + --num-steps-per-rollout 1 +) + +PERF_ARGS=( + --tensor-model-parallel-size 4 + --sequence-parallel + --pipeline-model-parallel-size 1 + --context-parallel-size 1 + --expert-model-parallel-size 1 + --expert-tensor-parallel-size 1 + + --recompute-granularity full + --recompute-method uniform + --recompute-num-layers 1 + + --use-dynamic-batch-size + --max-tokens-per-gpu 32768 + --log-probs-chunk-size 1024 +) + +GRPO_ARGS=( + --advantage-estimator grpo + --disable-rewards-normalization + --use-kl-loss + --kl-loss-coef 0.0 + --kl-loss-type low_var_kl + --entropy-coef 0.00 + --eps-clip 0.2 + --eps-clip-high 0.28 +) + +OPTIMIZER_ARGS=( + --optimizer adam + --lr 1e-5 + --lr-decay-style constant + --weight-decay 0.1 + --adam-beta1 0.9 + --adam-beta2 0.98 + --optimizer-cpu-offload + --overlap-cpu-optimizer-d2h-h2d + --use-precision-aware-optimizer +) + +EVAL_ARGS=() + +SGLANG_ARGS=( + --rollout-num-gpus-per-engine 2 + --sglang-tool-call-parser "${TOOL_CALL_PARSER}" + --sglang-mem-fraction-static 0.85 + --sglang-context-length 32768 + --sglang-reasoning-parser "${REASONING_PARSER}" +) + +if [ "${SGLANG_LANGUAGE_ONLY}" = "1" ]; then + SGLANG_ARGS+=(--sglang-language-only) +fi + +PRM_ARGS=( + --prm-enable + --prm-num-gpus "${PRM_GPUS}" + --prm-num-gpus-per-engine 2 + --prm-model-path "${PRM_MODEL_PATH}" + --prm-m "${PRM_M}" + --prm-temperature "${PRM_TEMPERATURE:-0.6}" + --prm-max-new-tokens "${PRM_MAX_NEW_TOKENS:-4096}" +) + +CUSTOM_ARGS=( + --custom-generate-function-path openclaw_api_server.generate + --custom-rm-path openclaw_api_server.reward_func +) + +MISC_ARGS=( + --attention-dropout 0.0 + --hidden-dropout 0.0 + --accumulate-allreduce-grads-in-fp32 + --attention-softmax-in-fp32 + --attention-backend flash +) + +USE_WANDB=${USE_WANDB:-1} +WANDB_PROJECT=${WANDB_PROJECT:-openclaw_rl} +WANDB_KEY_VALUE=${WANDB_KEY:-${WANDB_API_KEY:-}} +if [ "${USE_WANDB}" = "1" ] && [ -n "${WANDB_KEY_VALUE}" ]; then + WANDB_ARGS=( + --use-wandb + --wandb-project ${WANDB_PROJECT} + --wandb-group qwen35-4b-openclaw-rl + --wandb-key ${WANDB_KEY_VALUE} + ) +else + WANDB_ARGS=() +fi + +export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} +export no_proxy="127.0.0.1,${MASTER_ADDR}" +ray start --head --node-ip-address "${MASTER_ADDR}" 
--num-gpus "${NUM_GPUS}" --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265 + +RUNTIME_ENV_JSON="{ + \"env_vars\": { + \"PYTHONPATH\": \"${REPO_ROOT}/Megatron-LM/:${SCRIPT_DIR}:${SLIME_ROOT}\", + \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\", + \"FLASHINFER_WORKSPACE_BASE\": \"${FLASHINFER_WORKSPACE_BASE}\" + } +}" + +ray job submit --address="http://127.0.0.1:8265" \ + --runtime-env-json="${RUNTIME_ENV_JSON}" \ + -- python3 "${SLIME_ROOT}/train_async.py" \ + --actor-num-nodes 1 \ + --actor-num-gpus-per-node "${ACTOR_GPUS}" \ + --rollout-num-gpus "${ROLLOUT_GPUS}" \ + --num-gpus-per-node "${NUM_GPUS}" \ + ${MODEL_ARGS[@]} \ + ${CKPT_ARGS[@]} \ + ${ROLLOUT_ARGS[@]} \ + ${OPTIMIZER_ARGS[@]} \ + ${GRPO_ARGS[@]} \ + ${PERF_ARGS[@]} \ + ${EVAL_ARGS[@]} \ + ${SGLANG_ARGS[@]} \ + ${MISC_ARGS[@]} \ + ${WANDB_ARGS[@]} \ + ${CUSTOM_ARGS[@]} \ + ${PRM_ARGS[@]} diff --git a/openclaw-rl/run_qwen35_4b_openclaw_rl_lora.sh b/openclaw-rl/run_qwen35_4b_openclaw_rl_lora.sh new file mode 100644 index 0000000000..03ffc82ff5 --- /dev/null +++ b/openclaw-rl/run_qwen35_4b_openclaw_rl_lora.sh @@ -0,0 +1,193 @@ +#!/bin/bash +# Qwen3.5-4B LoRA training with OpenClaw RL (Binary RL / GRPO) +# Fewer GPUs than full fine-tuning: default 4 GPUs (2 actor + 1 rollout + 1 PRM) +# Uses FSDP backend instead of Megatron TP + +SKIP_CLUSTER_CLEANUP=${SKIP_CLUSTER_CLEANUP:-0} +if [ "${SKIP_CLUSTER_CLEANUP}" != "1" ]; then + pkill -9 sglang + sleep 3 + ray stop --force + pkill -9 ray + pkill -9 python + sleep 3 + pkill -9 ray + pkill -9 python +fi + +set -ex + +# keep stdout/stderr unbuffered in ray jobs +export PYTHONUNBUFFERED=1 +export PYTHONFAULTHANDLER=1 +export FLASHINFER_WORKSPACE_BASE="${FLASHINFER_WORKSPACE_BASE:-/tmp}" + +NUM_GPUS=${NUM_GPUS:-4} +ACTOR_GPUS=${ACTOR_GPUS:-2} +ROLLOUT_GPUS=${ROLLOUT_GPUS:-1} +PRM_GPUS=${PRM_GPUS:-1} + +if (( ACTOR_GPUS + ROLLOUT_GPUS + PRM_GPUS > NUM_GPUS )); then + echo "ACTOR_GPUS + ROLLOUT_GPUS + PRM_GPUS must be <= NUM_GPUS" + echo "ACTOR_GPUS=${ACTOR_GPUS}, ROLLOUT_GPUS=${ROLLOUT_GPUS}, PRM_GPUS=${PRM_GPUS}, NUM_GPUS=${NUM_GPUS}" + exit 1 +fi + +export RAY_health_check_failure_threshold=20 +export RAY_health_check_period_ms=5000 +export RAY_health_check_timeout_ms=30000 +export RAY_num_heartbeats_timeout=60 + +SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" +SLIME_ROOT="$(cd -- "${SCRIPT_DIR}/../slime" &>/dev/null && pwd)" +REPO_ROOT="$(cd -- "${SCRIPT_DIR}/.." 
&>/dev/null && pwd)" + +HF_CKPT=${HF_CKPT:-${REPO_ROOT}/models/Qwen3.5-4B} +REF_LOAD=${REF_LOAD:-${HF_CKPT}} +SAVE_CKPT=${SAVE_CKPT:-${REPO_ROOT}/ckpt/qwen35-4b-openclaw-rl-lora} +PRM_MODEL_PATH=${PRM_MODEL_PATH:-${HF_CKPT}} + +export SGLANG_API_KEY="${SGLANG_API_KEY}" +export SERVED_MODEL_NAME="qwen3.5-4b-lora" +export HOST="0.0.0.0" +export PORT="30000" +export OPENCLAW_RECORD_ENABLED="${OPENCLAW_RECORD_ENABLED:-1}" +export OPENCLAW_RECORD_FILE="${SCRIPT_DIR}/results/qwen35_4b_lora_record.jsonl" +export TP="${TP:-1}" +export CONTEXT_LENGTH="32768" +export MEM_FRACTION_STATIC="0.85" +export REASONING_PARSER="${REASONING_PARSER:-qwen3}" +export TOOL_CALL_PARSER="${TOOL_CALL_PARSER:-qwen25}" +export SGLANG_LANGUAGE_ONLY="${SGLANG_LANGUAGE_ONLY:-1}" +export PRM_M="${PRM_M:-3}" + + +CKPT_ARGS=( + --hf-checkpoint "${HF_CKPT}" + --ref-load "${REF_LOAD}" + --save "${SAVE_CKPT}" + --save-interval 1 +) + +ROLLOUT_ARGS=( + --disable-rollout-global-dataset + --rollout-function-path openclaw_rollout.generate_rollout_openclaw + + --num-rollout 100000000 + --rollout-batch-size 16 + --n-samples-per-prompt 1 + --rollout-max-response-len 8192 + --rollout-max-context-len 32768 + --rollout-temperature 0.6 + --reward-key score + + --num-steps-per-rollout 1 +) + +PERF_ARGS=( + --use-dynamic-batch-size + --max-tokens-per-gpu "${MAX_TOKENS_PER_GPU:-8192}" + --gradient-checkpointing +) + +GRPO_ARGS=( + --advantage-estimator grpo + --disable-rewards-normalization + --use-kl-loss + --kl-loss-coef 0.0 + --kl-loss-type low_var_kl + --entropy-coef 0.00 + --eps-clip 0.2 + --eps-clip-high 0.28 +) + +OPTIMIZER_ARGS=( + --optimizer adam + --lr 1e-5 + --lr-decay-style constant + --weight-decay 0.1 + --adam-beta1 0.9 + --adam-beta2 0.98 +) + +LORA_ARGS=( + --use-lora + --lora-rank 16 + --lora-alpha 32 + --lora-target-modules "q_proj,k_proj,v_proj,o_proj,gate_proj,up_proj,down_proj" +) + +EVAL_ARGS=() + +SGLANG_ARGS=( + --rollout-num-gpus-per-engine "${TP}" + --sglang-tool-call-parser "${TOOL_CALL_PARSER}" + --sglang-mem-fraction-static 0.85 + --sglang-context-length 32768 + --sglang-reasoning-parser "${REASONING_PARSER}" +) + +if [ "${SGLANG_LANGUAGE_ONLY}" = "1" ]; then + SGLANG_ARGS+=(--sglang-language-only) +fi + +PRM_ARGS=( + --prm-enable + --prm-num-gpus "${PRM_GPUS}" + --prm-num-gpus-per-engine "${PRM_TP:-${TP}}" + --prm-model-path "${PRM_MODEL_PATH}" + --prm-m "${PRM_M}" + --prm-temperature "${PRM_TEMPERATURE:-0.6}" + --prm-max-new-tokens "${PRM_MAX_NEW_TOKENS:-4096}" +) + +CUSTOM_ARGS=( + --custom-generate-function-path openclaw_api_server.generate + --custom-rm-path openclaw_api_server.reward_func +) + +USE_WANDB=${USE_WANDB:-1} +WANDB_PROJECT=${WANDB_PROJECT:-openclaw_rl} +WANDB_KEY_VALUE=${WANDB_KEY:-${WANDB_API_KEY:-}} +if [ "${USE_WANDB}" = "1" ] && [ -n "${WANDB_KEY_VALUE}" ]; then + WANDB_ARGS=( + --use-wandb + --wandb-project ${WANDB_PROJECT} + --wandb-group qwen35-4b-openclaw-rl-lora + --wandb-key ${WANDB_KEY_VALUE} + ) +else + WANDB_ARGS=() +fi + +export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} +export no_proxy="127.0.0.1,${MASTER_ADDR}" +ray start --head --node-ip-address "${MASTER_ADDR}" --num-gpus "${NUM_GPUS}" --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265 + +RUNTIME_ENV_JSON="{ + \"env_vars\": { + \"PYTHONPATH\": \"${SCRIPT_DIR}:${SLIME_ROOT}\", + \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\", + \"FLASHINFER_WORKSPACE_BASE\": \"${FLASHINFER_WORKSPACE_BASE}\" + } +}" + +ray job submit --address="http://127.0.0.1:8265" \ + --runtime-env-json="${RUNTIME_ENV_JSON}" \ + -- python3 
"${SLIME_ROOT}/train_async.py" \ + --train-backend fsdp \ + --actor-num-nodes 1 \ + --actor-num-gpus-per-node "${ACTOR_GPUS}" \ + --rollout-num-gpus "${ROLLOUT_GPUS}" \ + --num-gpus-per-node "${NUM_GPUS}" \ + ${CKPT_ARGS[@]} \ + ${ROLLOUT_ARGS[@]} \ + ${OPTIMIZER_ARGS[@]} \ + ${GRPO_ARGS[@]} \ + ${PERF_ARGS[@]} \ + ${EVAL_ARGS[@]} \ + ${SGLANG_ARGS[@]} \ + ${WANDB_ARGS[@]} \ + ${CUSTOM_ARGS[@]} \ + ${PRM_ARGS[@]} \ + ${LORA_ARGS[@]} diff --git a/openclaw-test/README.md b/openclaw-test/README.md index 5b949ae9c4..3f5fd02aec 100644 --- a/openclaw-test/README.md +++ b/openclaw-test/README.md @@ -115,32 +115,38 @@ cd slime **Combined (RL + OPD):** ```bash -bash ../openclaw-combine/run_qwen3_4b_openclaw_combine.sh +bash ../openclaw-combine/run_qwen3_4b_openclaw_combine.sh # Qwen3 +bash ../openclaw-combine/run_qwen35_4b_openclaw_combine.sh # Qwen3.5 ``` **Combined with LoRA** (parameter-efficient, fewer GPUs): ```bash -bash ../openclaw-combine/run_qwen3_4b_openclaw_combine_lora.sh +bash ../openclaw-combine/run_qwen3_4b_openclaw_combine_lora.sh # Qwen3 +bash ../openclaw-combine/run_qwen35_4b_openclaw_combine_lora.sh # Qwen3.5 ``` **Binary RL:** ```bash -bash ../openclaw-rl/run_qwen3_4b_openclaw_rl.sh +bash ../openclaw-rl/run_qwen3_4b_openclaw_rl.sh # Qwen3 +bash ../openclaw-rl/run_qwen35_4b_openclaw_rl.sh # Qwen3.5 ``` **Binary RL with LoRA** (parameter-efficient, fewer GPUs): ```bash -bash ../openclaw-rl/run_qwen3_4b_openclaw_rl_lora.sh +bash ../openclaw-rl/run_qwen3_4b_openclaw_rl_lora.sh # Qwen3 +bash ../openclaw-rl/run_qwen35_4b_openclaw_rl_lora.sh # Qwen3.5 ``` **On-Policy Distillation (OPD):** ```bash -bash ../openclaw-opd/run_qwen3_4b_openclaw_opd.sh +bash ../openclaw-opd/run_qwen3_4b_openclaw_opd.sh # Qwen3 +bash ../openclaw-opd/run_qwen35_4b_openclaw_opd.sh # Qwen3.5 ``` **OPD with LoRA** (parameter-efficient, fewer GPUs): ```bash -bash ../openclaw-opd/run_qwen3_4b_openclaw_opd_topk_lora.sh +bash ../openclaw-opd/run_qwen3_4b_openclaw_opd_topk_lora.sh # Qwen3 +bash ../openclaw-opd/run_qwen35_4b_openclaw_opd_topk_lora.sh # Qwen3.5 ``` > **Eval mode:** To enable evaluation logging with W&B, set `OPENCLAW_EVAL_MODE=1` and provide your W&B key via `WANDB_KEY` before launching. This is already the default in the OPD and Combine scripts. diff --git a/slime/README.md b/slime/README.md index f2fe743c66..cad8d20c04 100644 --- a/slime/README.md +++ b/slime/README.md @@ -11,7 +11,7 @@ 2. **Flexible Data Generation**: Enables arbitrary training data generation workflows through custom data generation interfaces and server-based engines. slime is the RL-framework behind [GLM-4.7](https://z.ai/blog/glm-4.7), [GLM-4.6](https://z.ai/blog/glm-4.6), [GLM-4.5](https://z.ai/blog/glm-4.5) and apart from models from Z.ai, we also supports the following models: -- Qwen3 series (Qwen3Next, Qwen3MoE, Qwen3), Qwen2.5 series; +- Qwen3.5 series, Qwen3 series (Qwen3Next, Qwen3MoE, Qwen3), Qwen2.5 series; - DeepSeek V3 series (DeepSeek V3, V3.1, DeepSeek R1); - Llama 3. 
diff --git a/slime/scripts/models/qwen3.5-4B.sh b/slime/scripts/models/qwen3.5-4B.sh new file mode 100644 index 0000000000..bb0b614155 --- /dev/null +++ b/slime/scripts/models/qwen3.5-4B.sh @@ -0,0 +1,25 @@ +MODEL_ARGS=( + --spec "slime_plugins.models.qwen3_5" "get_qwen3_5_spec" + + --disable-bias-linear + --qk-layernorm + --group-query-attention + --num-attention-heads 16 + --num-query-groups 4 + --kv-channels 256 + --num-layers 32 + --hidden-size 2560 + --ffn-hidden-size 9216 + --use-gated-attention + --attention-output-gate + + --normalization RMSNorm + --apply-layernorm-1p + --position-embedding-type rope + --norm-epsilon 1e-6 + --rotary-percent 0.25 + --swiglu + --vocab-size 248320 + + --rotary-base "${MODEL_ARGS_ROTARY_BASE:-10000000}" +) diff --git a/slime/slime/backends/fsdp_utils/actor.py b/slime/slime/backends/fsdp_utils/actor.py index c37395d446..3c5c271748 100644 --- a/slime/slime/backends/fsdp_utils/actor.py +++ b/slime/slime/backends/fsdp_utils/actor.py @@ -997,7 +997,7 @@ def apply_fsdp2(model, mesh=None, cpu_offload=False, args=None): if not layer_cls_to_wrap and hasattr(model, "base_model"): inner = getattr(model.base_model, "model", model.base_model) layer_cls_to_wrap = getattr(inner, "_no_split_modules", None) - assert layer_cls_to_wrap and len(layer_cls_to_wrap) > 0 and layer_cls_to_wrap[0] is not None + assert layer_cls_to_wrap and len(layer_cls_to_wrap) > 0 and next(iter(layer_cls_to_wrap)) is not None modules = [ module diff --git a/slime/slime/backends/megatron_utils/megatron_to_hf/__init__.py b/slime/slime/backends/megatron_utils/megatron_to_hf/__init__.py index 28af98ca48..a897ebbc0d 100644 --- a/slime/slime/backends/megatron_utils/megatron_to_hf/__init__.py +++ b/slime/slime/backends/megatron_utils/megatron_to_hf/__init__.py @@ -5,6 +5,7 @@ from .mimo import convert_mimo_to_hf from .processors import quantize_params, remove_padding from .qwen2 import convert_qwen2_to_hf +from .qwen3_5 import convert_qwen3_5_to_hf from .qwen3_next import convert_qwen3_next_to_hf from .qwen3_vl import convert_qwen3vl_to_hf from .qwen3moe import convert_qwen3moe_to_hf @@ -40,6 +41,8 @@ def _convert_to_hf_core(args, model_name, name, param): converted_named_tensors = convert_glm4_to_hf(args, name, param) elif "qwen3moe" in model_name: converted_named_tensors = convert_qwen3moe_to_hf(args, name, param) + elif "qwen3_5" in model_name: + converted_named_tensors = convert_qwen3_5_to_hf(args, name, param) elif "qwen3next" in model_name: converted_named_tensors = convert_qwen3_next_to_hf(args, name, param) elif "qwen3vl" in model_name: diff --git a/slime/slime/backends/megatron_utils/megatron_to_hf/qwen3_5.py b/slime/slime/backends/megatron_utils/megatron_to_hf/qwen3_5.py new file mode 100644 index 0000000000..36bbe35962 --- /dev/null +++ b/slime/slime/backends/megatron_utils/megatron_to_hf/qwen3_5.py @@ -0,0 +1,136 @@ +import re + +import torch + + +def _convert_mtp_layer(args, name, param, layer_idx): + if "enorm.weight" in name: + return [("mtp.pre_fc_norm_embedding.weight", param)] + if "hnorm.weight" in name: + return [("mtp.pre_fc_norm_hidden.weight", param)] + if "final_layernorm.weight" in name: + return [("mtp.norm.weight", param)] + if "eh_proj.weight" in name: + if param.dim() < 2: + raise ValueError(f"eh_proj weight expects 2D tensor, got {param.shape}") + first_half, second_half = param.chunk(2, dim=1) + new_param = torch.cat([second_half, first_half], dim=1) + return [("mtp.fc.weight", new_param)] + + if "transformer_layer" in name: + proxy_name = 
name.replace(f"mtp.layers.{layer_idx}.transformer_layer", f"decoder.layers.{layer_idx}") + mapped_params = convert_qwen3_5_to_hf(args, proxy_name, param) + + final_params = [] + for hf_name, tensor in mapped_params: + target_prefix = f"mtp.layers.{layer_idx}" + if f"model.layers.{layer_idx}" in hf_name: + new_hf_name = hf_name.replace(f"model.layers.{layer_idx}", target_prefix) + final_params.append((new_hf_name, tensor)) + else: + final_params.append((hf_name, tensor)) + return final_params + + return None + + +def convert_qwen3_5_to_hf(args, name, param): + if "mtp.layers" in name: + parts = name.split(".") + try: + layer_idx_loc = parts.index("layers") + 1 + layer_idx = parts[layer_idx_loc] + except (ValueError, IndexError) as e: + raise ValueError(f"Invalid MTP layer name format: {name}") from e + + result = _convert_mtp_layer(args, name, param, layer_idx) + if result is not None: + return result + + if name == "module.module.embedding.word_embeddings.weight": + return [("model.embed_tokens.weight", param)] + if name == "module.module.output_layer.weight": + return [("lm_head.weight", param)] + if name == "module.module.decoder.final_layernorm.weight": + return [("model.norm.weight", param)] + + try: + head_dim = args.kv_channels if args.kv_channels is not None else args.hidden_size // args.num_attention_heads + except AttributeError: + head_dim = args.hidden_size // args.num_attention_heads + value_num_per_group = args.num_attention_heads // args.num_query_groups + + decoder_layers_pattern = r"module\.module\.decoder\.layers\.(\d+)\.(.+)" + match = re.match(decoder_layers_pattern, name) + if match: + layer_idx, rest = match.groups() + + if rest == "self_attention.linear_proj.weight": + return [(f"model.layers.{layer_idx}.self_attn.o_proj.weight", param)] + elif rest == "self_attention.linear_qkv.weight": + param = param.view(args.num_query_groups, -1, head_dim, args.hidden_size) + q_param, k_param, v_param = torch.split( + param, split_size_or_sections=[2 * value_num_per_group, 1, 1], dim=1 + ) + q_param = ( + q_param.reshape(args.num_query_groups, 2, value_num_per_group, head_dim, args.hidden_size) + .transpose(1, 2) + .reshape(-1, args.hidden_size) + ) + k_param = k_param.reshape(-1, args.hidden_size) + v_param = v_param.reshape(-1, args.hidden_size) + return [ + (f"model.layers.{layer_idx}.self_attn.q_proj.weight", q_param), + (f"model.layers.{layer_idx}.self_attn.k_proj.weight", k_param), + (f"model.layers.{layer_idx}.self_attn.v_proj.weight", v_param), + ] + elif rest == "self_attention.linear_qkv.bias": + param = param.view(args.num_query_groups, -1) + q_bias, k_bias, v_bias = torch.split( + param, + split_size_or_sections=[value_num_per_group * head_dim * 2, head_dim, head_dim], + dim=1, + ) + q_bias = q_bias.contiguous().flatten() + k_bias = k_bias.contiguous().flatten() + v_bias = v_bias.contiguous().flatten() + return [ + (f"model.layers.{layer_idx}.self_attn.q_proj.bias", q_bias), + (f"model.layers.{layer_idx}.self_attn.k_proj.bias", k_bias), + (f"model.layers.{layer_idx}.self_attn.v_proj.bias", v_bias), + ] + elif rest == "mlp.linear_fc1.weight": + gate_weight, up_weight = param.chunk(2, dim=0) + return [ + (f"model.layers.{layer_idx}.mlp.gate_proj.weight", gate_weight), + (f"model.layers.{layer_idx}.mlp.up_proj.weight", up_weight), + ] + elif rest == "mlp.linear_fc2.weight": + return [(f"model.layers.{layer_idx}.mlp.down_proj.weight", param)] + elif rest == "self_attention.linear_qkv.layer_norm_weight": + return [(f"model.layers.{layer_idx}.input_layernorm.weight", 
param)] + elif rest == "mlp.linear_fc1.layer_norm_weight": + return [(f"model.layers.{layer_idx}.post_attention_layernorm.weight", param)] + elif rest == "pre_mlp_layernorm.weight": + return [(f"model.layers.{layer_idx}.post_attention_layernorm.weight", param)] + + elif rest == "self_attention.q_layernorm.weight": + return [(f"model.layers.{layer_idx}.self_attn.q_norm.weight", param)] + elif rest == "self_attention.k_layernorm.weight": + return [(f"model.layers.{layer_idx}.self_attn.k_norm.weight", param)] + elif rest.startswith("self_attention.") and rest[len("self_attention.") :] in [ + "input_layernorm.weight", + "linear_attn.A_log", + "linear_attn.conv1d.weight", + "linear_attn.dt_bias", + "linear_attn.in_proj_a.weight", + "linear_attn.in_proj_b.weight", + "linear_attn.in_proj_qkv.weight", + "linear_attn.in_proj_z.weight", + "linear_attn.norm.weight", + "linear_attn.out_proj.weight", + ]: + rest = rest[len("self_attention.") :] + return [(f"model.layers.{layer_idx}.{rest}", param)] + + raise ValueError(f"Unknown parameter name: {name}") diff --git a/slime/slime/backends/megatron_utils/model_provider.py b/slime/slime/backends/megatron_utils/model_provider.py index 31db8b0da8..7d49136f85 100644 --- a/slime/slime/backends/megatron_utils/model_provider.py +++ b/slime/slime/backends/megatron_utils/model_provider.py @@ -83,6 +83,8 @@ def wrapped_model_provider( if args.megatron_to_hf_mode == "bridge": from megatron.bridge import AutoBridge + import slime_plugins.megatron_bridge # noqa: F401 # register custom bridges + bridge = AutoBridge.from_hf_pretrained(args.hf_checkpoint, trust_remote_code=True) provider = bridge.to_megatron_provider(load_weights=False) # TODO: we should not manually set this... diff --git a/slime/slime/backends/sglang_utils/qwen3_5.py b/slime/slime/backends/sglang_utils/qwen3_5.py new file mode 100644 index 0000000000..f6d7db28fe --- /dev/null +++ b/slime/slime/backends/sglang_utils/qwen3_5.py @@ -0,0 +1,382 @@ +import hashlib +import json +import logging +import os +import shutil +import tempfile +from pathlib import Path + +from transformers import AutoConfig + +logger = logging.getLogger(__name__) + + +def is_qwen35_model_path(model_path: str) -> bool: + try: + hf_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) + except Exception: + return False + return getattr(hf_config, "model_type", None) in {"qwen3_5", "qwen3_5_text", "qwen3_5_moe_text"} + + +def maybe_prepare_qwen35_text_model(model_path: str, *, language_only: bool) -> str: + if not language_only: + return model_path + + hf_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) + if getattr(hf_config, "model_type", None) != "qwen3_5" or not hasattr(hf_config, "text_config"): + return model_path + + target_dir = _get_shadow_model_dir(model_path) + config_path = target_dir / "config.json" + if config_path.exists(): + return str(target_dir) + + target_dir.parent.mkdir(parents=True, exist_ok=True) + temp_dir = Path(tempfile.mkdtemp(prefix=target_dir.name + ".", dir=target_dir.parent)) + try: + _populate_shadow_model_dir(source_dir=Path(model_path), target_dir=temp_dir, hf_config=hf_config) + os.replace(temp_dir, target_dir) + except FileExistsError: + shutil.rmtree(temp_dir, ignore_errors=True) + except Exception: + shutil.rmtree(temp_dir, ignore_errors=True) + raise + + logger.info("Prepared Qwen3.5 text-only shadow model at %s", target_dir) + return str(target_dir) + + +_qwen35_patched = False + + +def patch_sglang_qwen35() -> None: + global _qwen35_patched + if 
_qwen35_patched: + return + _qwen35_patched = True + import torch + import torch.nn as nn + from transformers.models.qwen3_5.configuration_qwen3_5 import Qwen3_5Config, Qwen3_5TextConfig, Qwen3_5VisionConfig + + from sglang.srt.configs.mamba_utils import Mamba2CacheParams, Mamba2StateShape, mamba2_state_dtype + from sglang.srt.configs.update_config import adjust_tp_num_heads_if_necessary + from sglang.srt.distributed import get_pp_group + from sglang.srt.layers.logits_processor import LogitsProcessor + from sglang.srt.layers.utils.common import PPMissingLayer + from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead + from sglang.srt.layers.dp_attention import get_attention_tp_size + from sglang.srt.model_executor.model_runner import ModelRunner + from sglang.srt.model_loader.weight_utils import default_weight_loader + from sglang.srt.models import registry as registry_module + from sglang.srt.models import qwen3_5 as qwen3_5_model + from sglang.srt.server_args import get_global_server_args + from sglang.srt.utils import add_prefix + from sglang.srt.utils import is_cpu + + original_qwen35_dense_cls = qwen3_5_model.Qwen3_5ForCausalLM + + class PatchedQwen35ForCausalLM(nn.Module): + def __init__( + self, + config: Qwen3_5TextConfig, + quant_config=None, + prefix: str = "", + ) -> None: + super().__init__() + self.pp_group = get_pp_group() + self.config = config + self.quant_config = quant_config + self.model = original_qwen35_dense_cls( + config=config, + quant_config=quant_config, + prefix=add_prefix("model", prefix), + ) + + if self.pp_group.is_last_rank: + if self.pp_group.world_size == 1 and config.tie_word_embeddings: + self.lm_head = self.model.embed_tokens + else: + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + org_num_embeddings=config.vocab_size, + prefix=add_prefix("lm_head", prefix), + use_attn_tp_group=get_global_server_args().enable_dp_lm_head, + ) + else: + self.lm_head = PPMissingLayer() + + self.logits_processor = LogitsProcessor(config) + self.capture_aux_hidden_states = False + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + @property + def start_layer(self): + return self.model.layers.start_layer + + @property + def end_layer(self): + return self.model.layers.end_layer + + @torch.no_grad() + def forward( + self, + input_ids, + positions, + forward_batch, + input_embeds=None, + pp_proxy_tensors=None, + input_deepstack_embeds=None, + **kwargs, + ): + hidden_states = self.model( + input_ids=input_ids, + positions=positions, + forward_batch=forward_batch, + input_embeds=input_embeds, + pp_proxy_tensors=pp_proxy_tensors, + input_deepstack_embeds=input_deepstack_embeds, + ) + + aux_hidden_states = None + if self.capture_aux_hidden_states: + hidden_states, aux_hidden_states = hidden_states + + if self.pp_group.is_last_rank: + return self.logits_processor( + input_ids, + hidden_states, + self.lm_head, + forward_batch, + aux_hidden_states, + ) + return hidden_states + + def load_weights(self, weights): + stacked_params_mapping = [ + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + params_dict = dict(self.named_parameters(remove_duplicate=False)) + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name or "mtp" in name or "visual" in name: + continue + if "language_model" in name: + name = name.replace(r"model.language_model.", r"model.") + if 
".self_attn." in name: + name = name.replace(".self_attn", "") + if not name.startswith("model.") and ( + name.startswith("layers.") + or name.startswith("embed_tokens.") + or name.startswith("norm.") + ): + name = add_prefix(name, "model") + + if name == "model.embed_tokens.weight": + if self.pp_group.is_last_rank and self.config.tie_word_embeddings: + lm_head_weight = params_dict.get("lm_head.weight") + if lm_head_weight is not None: + weight_loader = getattr(lm_head_weight, "weight_loader", default_weight_loader) + weight_loader(lm_head_weight, loaded_weight) + + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + if "mlp.experts" in name: + continue + mapped_name = name.replace(weight_name, param_name) + if mapped_name.endswith(".bias") and mapped_name not in params_dict: + continue + if mapped_name not in params_dict: + continue + param = params_dict[mapped_name] + weight_loader = getattr(param, "weight_loader") + weight_loader(param, loaded_weight, shard_id) + break + else: + if name.endswith(".bias") and name not in params_dict: + continue + if name not in params_dict: + logger.warning("Parameter %s not found in params_dict", name) + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + + def get_embed_and_head(self): + return self.model.embed_tokens.weight, self.lm_head.weight + + def set_embed_and_head(self, embed, head): + del self.model.embed_tokens.weight + del self.lm_head.weight + self.model.embed_tokens.weight = embed + self.lm_head.weight = head + torch.cuda.empty_cache() + torch.cuda.synchronize() + + PatchedQwen35ForCausalLM.__name__ = "Qwen3_5ForCausalLM" + qwen3_5_model.Qwen3_5ForCausalLM = PatchedQwen35ForCausalLM + + entry_classes = [ + qwen3_5_model.Qwen3_5MoeForConditionalGeneration, + qwen3_5_model.Qwen3_5ForConditionalGeneration, + qwen3_5_model.Qwen3_5MoeForCausalLM, + qwen3_5_model.Qwen3_5ForCausalLM, + ] + deduped = [] + seen = set() + for cls in entry_classes: + if cls not in seen: + deduped.append(cls) + seen.add(cls) + qwen3_5_model.EntryClass = deduped + + def _get_model_config_for_expert_location(cls, config): + text_config = getattr(config, "text_config", config) + num_experts = getattr(text_config, "num_experts", None) + if not num_experts: + return None + return qwen3_5_model.ModelConfigForExpertLocation( + num_layers=text_config.num_hidden_layers, + num_logical_experts=num_experts, + num_groups=None, + ) + + for cls in [ + qwen3_5_model.Qwen3_5ForCausalLM, + qwen3_5_model.Qwen3_5MoeForCausalLM, + qwen3_5_model.Qwen3_5ForConditionalGeneration, + qwen3_5_model.Qwen3_5MoeForConditionalGeneration, + ]: + cls.get_model_config_for_expert_location = classmethod(_get_model_config_for_expert_location) + + registry_module.import_model_classes.cache_clear() + for cls in deduped: + registry_module.ModelRegistry.models[cls.__name__] = cls + + _is_cpu = is_cpu() + + def _layers_block_type(self): + layer_types = getattr(self, "layer_types", None) or [] + return [ + "attention" if layer_type == "full_attention" else layer_type + for layer_type in layer_types + ] + + def _linear_layer_ids(self): + return [ + idx + for idx, layer_type in enumerate(self.layers_block_type) + if layer_type == "linear_attention" + ] + + def _full_attention_layer_ids(self): + return [ + idx + for idx, layer_type in enumerate(self.layers_block_type) + if layer_type == "attention" + ] + + def _mamba2_cache_params(self): + if _is_cpu: + 
world_size = get_attention_tp_size() + adjust_tp_num_heads_if_necessary(self, world_size, False) + + shape = Mamba2StateShape.create( + tp_world_size=get_attention_tp_size(), + intermediate_size=self.linear_value_head_dim * self.linear_num_value_heads, + n_groups=self.linear_num_key_heads, + num_heads=self.linear_num_value_heads, + head_dim=self.linear_value_head_dim, + state_size=self.linear_key_head_dim, + conv_kernel=self.linear_conv_kernel_dim, + ) + + return Mamba2CacheParams( + shape=shape, layers=self.linear_layer_ids, dtype=mamba2_state_dtype(self) + ) + + Qwen3_5TextConfig.layers_block_type = property(_layers_block_type) + Qwen3_5TextConfig.linear_layer_ids = property(_linear_layer_ids) + Qwen3_5TextConfig.full_attention_layer_ids = property(_full_attention_layer_ids) + Qwen3_5TextConfig.mamba2_cache_params = property(_mamba2_cache_params) + + original_hybrid_gdn_config = ModelRunner.hybrid_gdn_config.fget + + def _hybrid_gdn_config(self): + config = self.model_config.hf_config.get_text_config() + if isinstance(config, (Qwen3_5Config, Qwen3_5TextConfig, qwen3_5_model.Qwen3_5MoeConfig)): + _ensure_qwen35_attention_layer_ids(config) + return config + if isinstance(config, Qwen3_5VisionConfig): + text_config = getattr(config, "text_config", None) + if isinstance(text_config, (Qwen3_5Config, Qwen3_5TextConfig, qwen3_5_model.Qwen3_5MoeConfig)): + _ensure_qwen35_attention_layer_ids(text_config) + return text_config + return original_hybrid_gdn_config(self) + + ModelRunner.hybrid_gdn_config = property(_hybrid_gdn_config) + + +def _get_shadow_model_dir(model_path: str) -> Path: + source = Path(model_path).resolve() + source_hash = hashlib.sha256(str(source).encode("utf-8")).hexdigest()[:16] + cache_root = os.environ.get("SLIME_SGLANG_MODEL_CACHE_DIR") + if cache_root: + base_dir = Path(cache_root) + else: + base_dir = Path(tempfile.gettempdir()) / "slime-sglang-models" + return base_dir / f"qwen3_5_text_v6_{source_hash}" + + +def _populate_shadow_model_dir(source_dir: Path, target_dir: Path, hf_config) -> None: + for entry in source_dir.iterdir(): + if entry.name == "config.json": + continue + (target_dir / entry.name).symlink_to(entry) + + text_config = hf_config.text_config + text_config.architectures = ["Qwen3_5ForCausalLM"] + text_config.model_type = "qwen3_5_text" + text_config._name_or_path = str(source_dir) + config_dict = text_config.to_dict() + config_dict["architectures"] = ["Qwen3_5ForCausalLM"] + config_dict["model_type"] = "qwen3_5_text" + if "rope_theta" not in config_dict: + rope_theta = None + if isinstance(config_dict.get("rope_parameters"), dict): + rope_theta = config_dict["rope_parameters"].get("rope_theta") + if rope_theta is None and isinstance(config_dict.get("rope_scaling"), dict): + rope_theta = config_dict["rope_scaling"].get("rope_theta") + if rope_theta is not None: + config_dict["rope_theta"] = rope_theta + config_path = target_dir / "config.json" + config_path.write_text(json.dumps(config_dict, indent=2, sort_keys=True) + "\n") + + +def _compute_attention_layer_ids(config) -> tuple[list[int], list[int]]: + layer_types = getattr(config, "layer_types", None) or getattr(config, "layers_block_type", None) or [] + full_attention_layer_ids = [] + linear_attention_layer_ids = [] + for idx, layer_type in enumerate(layer_types): + if layer_type in {"full_attention", "attention"}: + full_attention_layer_ids.append(idx) + elif layer_type == "linear_attention": + linear_attention_layer_ids.append(idx) + return full_attention_layer_ids, linear_attention_layer_ids + + +def 
_ensure_qwen35_attention_layer_ids(config) -> None: + full_attention_layer_ids, linear_attention_layer_ids = _compute_attention_layer_ids(config) + if not hasattr(config, "full_attention_layer_ids"): + config.full_attention_layer_ids = full_attention_layer_ids + if not hasattr(config, "linear_attention_layer_ids"): + config.linear_attention_layer_ids = linear_attention_layer_ids diff --git a/slime/slime/backends/sglang_utils/sglang_engine.py b/slime/slime/backends/sglang_utils/sglang_engine.py index 2f50628787..30a4c70cd9 100644 --- a/slime/slime/backends/sglang_utils/sglang_engine.py +++ b/slime/slime/backends/sglang_utils/sglang_engine.py @@ -13,6 +13,7 @@ from sglang.srt.utils import kill_process_tree from urllib3.exceptions import NewConnectionError +from .qwen3_5 import is_qwen35_model_path, maybe_prepare_qwen35_text_model, patch_sglang_qwen35 from slime.ray.ray_actor import RayActor from slime.utils.http_utils import get_host_info @@ -51,11 +52,9 @@ def _to_local_gpu_id(physical_gpu_id: int) -> int: def launch_server_process(server_args: ServerArgs) -> multiprocessing.Process: - from sglang.srt.entrypoints.http_server import launch_server - multiprocessing.set_start_method("spawn", force=True) server_args.host = server_args.host.strip("[]") - p = multiprocessing.Process(target=launch_server, args=(server_args,)) + p = multiprocessing.Process(target=_launch_server_entry, args=(server_args,)) p.start() if server_args.node_rank != 0: @@ -70,6 +69,13 @@ def launch_server_process(server_args: ServerArgs) -> multiprocessing.Process: return p +def _launch_server_entry(server_args: ServerArgs): + patch_sglang_qwen35() + from sglang.srt.entrypoints.http_server import launch_server + + launch_server(server_args) + + def _wait_server_healthy(base_url, api_key, is_process_alive): headers = { "Content-Type": "application/json; charset=utf-8", @@ -493,7 +499,22 @@ def _compute_server_args( ): is_prm = engine_role == "prm" gpus_per_engine = args.prm_num_gpus_per_engine if is_prm else args.rollout_num_gpus_per_engine - model_path = args.prm_model_path if is_prm else args.hf_checkpoint + original_model_path = args.prm_model_path if is_prm else (getattr(args, "rollout_model_path", None) or args.hf_checkpoint) + model_path = original_model_path + model_path = maybe_prepare_qwen35_text_model( + model_path, + language_only=getattr(args, "sglang_language_only", False), + ) + server_language_only = getattr(args, "sglang_language_only", False) + # Once Qwen3.5 has been materialized as a text-only shadow checkpoint, we should + # stop forwarding `language_only` to SGLang. Recent SGLang builds interpret the + # flag as encoder disaggregation and require `--encoder-urls`, even though the + # shadow checkpoint is already a plain text model. 
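+    # Only drop the flag when the swap actually happened and the resulting path is
+    # recognised as a Qwen3.5 checkpoint; otherwise forward the user's setting as-is.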
+ if model_path != original_model_path and is_qwen35_model_path(model_path): + server_language_only = False + if is_qwen35_model_path(model_path) or is_qwen35_model_path(original_model_path): + os.environ["SLIME_ENABLE_QWEN35_SGLANG_PATCH"] = "1" + os.environ["SGLANG_EXTERNAL_MODEL_PACKAGE"] = "slime_plugins.sglang_models" nnodes = max(1, gpus_per_engine // args.num_gpus_per_node) node_rank = rank % nnodes @@ -546,6 +567,10 @@ def _compute_server_args( for attr in dataclasses.fields(ServerArgs): if worker_type == "decode" and attr.name == "enable_hierarchical_cache": continue + if attr.name == "language_only": + kwargs[attr.name] = server_language_only + unused_keys.discard(attr.name) + continue if hasattr(args, f"sglang_{attr.name}") and attr.name not in kwargs: kwargs[attr.name] = getattr(args, f"sglang_{attr.name}") unused_keys.discard(attr.name) diff --git a/slime/slime_plugins/mbridge/__init__.py b/slime/slime_plugins/mbridge/__init__.py index 18f6fe954a..8e3d27a274 100644 --- a/slime/slime_plugins/mbridge/__init__.py +++ b/slime/slime_plugins/mbridge/__init__.py @@ -2,6 +2,7 @@ from .glm4moe import GLM4MoEBridge from .glm4moe_lite import GLM4MoELiteBridge from .mimo import MimoBridge +from .qwen3_5 import Qwen35Bridge from .qwen3_next import Qwen3NextBridge -__all__ = ["GLM4Bridge", "GLM4MoEBridge", "GLM4MoELiteBridge", "Qwen3NextBridge", "MimoBridge"] +__all__ = ["GLM4Bridge", "GLM4MoEBridge", "GLM4MoELiteBridge", "Qwen35Bridge", "Qwen3NextBridge", "MimoBridge"] diff --git a/slime/slime_plugins/mbridge/qwen3_5.py b/slime/slime_plugins/mbridge/qwen3_5.py new file mode 100644 index 0000000000..5a7aa58d75 --- /dev/null +++ b/slime/slime_plugins/mbridge/qwen3_5.py @@ -0,0 +1,57 @@ +from mbridge.core import register_model + +from .qwen3_next import Qwen3NextBridge + + +@register_model("qwen3_5") +@register_model("qwen3_5_text") +class Qwen35Bridge(Qwen3NextBridge): + _ATTENTION_MAPPING = ( + Qwen3NextBridge._ATTENTION_MAPPING + | { + f"self_attention.{weight_name}": ["model.layers.{layer_number}." 
+ weight_name] + for weight_name in [ + "input_layernorm.weight", + "linear_attn.A_log", + "linear_attn.conv1d.weight", + "linear_attn.dt_bias", + "linear_attn.in_proj_a.weight", + "linear_attn.in_proj_b.weight", + "linear_attn.in_proj_qkv.weight", + "linear_attn.in_proj_z.weight", + "linear_attn.norm.weight", + "linear_attn.out_proj.weight", + ] + } + ) + + _MLP_MAPPING = { + "mlp.linear_fc1.weight": [ + "model.layers.{layer_number}.mlp.gate_proj.weight", + "model.layers.{layer_number}.mlp.up_proj.weight", + ], + "mlp.linear_fc1.layer_norm_weight": ["model.layers.{layer_number}.post_attention_layernorm.weight"], + "mlp.linear_fc2.weight": ["model.layers.{layer_number}.mlp.down_proj.weight"], + } + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.full_hf_config = self.hf_config + if hasattr(self.hf_config, "text_config"): + self.hf_config = self.hf_config.text_config + + def _build_config(self): + mtp_args = {} + if hasattr(self.hf_config, "mtp_num_hidden_layers"): + mtp_args["mtp_num_layers"] = self.hf_config.mtp_num_hidden_layers + + return self._build_base_config( + use_cpu_initialization=False, + persist_layer_norm=True, + bias_activation_fusion=True, + bias_dropout_fusion=True, + qk_layernorm=True, + attention_output_gate=True, + rotary_interleaved=True, + **mtp_args, + ) diff --git a/slime/slime_plugins/megatron_bridge/__init__.py b/slime/slime_plugins/megatron_bridge/__init__.py index e69de29bb2..32c349576b 100644 --- a/slime/slime_plugins/megatron_bridge/__init__.py +++ b/slime/slime_plugins/megatron_bridge/__init__.py @@ -0,0 +1 @@ +import slime_plugins.megatron_bridge.qwen3_5 # noqa: F401 # register Qwen3.5 bridge diff --git a/slime/slime_plugins/megatron_bridge/qwen3_5.py b/slime/slime_plugins/megatron_bridge/qwen3_5.py new file mode 100644 index 0000000000..237f3da626 --- /dev/null +++ b/slime/slime_plugins/megatron_bridge/qwen3_5.py @@ -0,0 +1,173 @@ +""" +Qwen3.5 bridge for megatron.bridge. + +Registers `Qwen3_5ForConditionalGeneration` so that `AutoBridge.from_hf_pretrained` +recognises Qwen3.5 checkpoints and can provide a Megatron-compatible GPT model + +weight mappings. + +Qwen3.5 is a VLM with a hybrid (linear + full attention) text backbone. +For training we only use the text backbone, so this bridge delegates to the +existing Qwen3 provider infrastructure and adds linear-attention weight passthrough. 
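+Linear-attention parameters have no fused Megatron-side layout, so they are mapped
+one-to-one by name instead of being split or concatenated like QKV and gated MLP.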
+""" + +from __future__ import annotations + +import logging + +import torch +from megatron.bridge.models.conversion.mapping_registry import MegatronMappingRegistry +from megatron.bridge.models.conversion.model_bridge import MegatronModelBridge +from megatron.bridge.models.conversion.param_mapping import AutoMapping, GatedMLPMapping, QKVMapping +from megatron.core.models.gpt import GPTModel + +logger = logging.getLogger(__name__) + +# Use a string so we don't need transformers to have Qwen3.5 at import time +_Qwen3_5HF = "Qwen3_5ForConditionalGeneration" + + +def _get_text_config(hf_config): + """Unwrap text_config from VLM config if present.""" + return getattr(hf_config, "text_config", hf_config) + + +@MegatronModelBridge.register_bridge(source=_Qwen3_5HF, target=GPTModel) +class MegatronQwen35Bridge(MegatronModelBridge): + """Bridge between HuggingFace Qwen3.5 and Megatron GPTModel.""" + + def provider_bridge(self, hf_pretrained): + """Create a GPT ModelProvider from Qwen3.5 HF config.""" + from megatron.bridge.models.qwen.qwen_provider import Qwen3ModelProvider + + hf_config = hf_pretrained.config + text_config = _get_text_config(hf_config) + + model_dtype = self.dtype_from_hf(text_config, default=torch.bfloat16) + + rope_params = getattr(text_config, "rope_parameters", {}) or {} + rope_theta = rope_params.get("rope_theta", getattr(text_config, "rope_theta", 10000000)) + partial_rotary_factor = rope_params.get( + "partial_rotary_factor", getattr(text_config, "partial_rotary_factor", 0.25) + ) + + provider = Qwen3ModelProvider( + num_layers=text_config.num_hidden_layers, + hidden_size=text_config.hidden_size, + ffn_hidden_size=text_config.intermediate_size, + num_attention_heads=text_config.num_attention_heads, + num_query_groups=text_config.num_key_value_heads, + kv_channels=getattr(text_config, "head_dim", 256), + init_method_std=getattr(text_config, "initializer_range", 0.02), + layernorm_epsilon=text_config.rms_norm_eps, + gated_linear_unit=True, + make_vocab_size_divisible_by=self.make_vocab_size_divisible_by(text_config.vocab_size), + rotary_base=rope_theta, + rotary_percent=partial_rotary_factor, + share_embeddings_and_output_weights=getattr(text_config, "tie_word_embeddings", True), + vocab_size=text_config.vocab_size, + seq_length=getattr(text_config, "max_position_embeddings", 262144), + fp16=(model_dtype == torch.float16), + bf16=(model_dtype == torch.bfloat16), + params_dtype=model_dtype, + # Qwen3.5 specific + qk_layernorm=True, + attention_output_gate=True, + ) + + return provider + + def mapping_registry(self) -> MegatronMappingRegistry: + """Weight mappings from HF Qwen3.5 to Megatron format. + + Qwen3.5 uses model.language_model.layers prefix (VLM structure) for weights, + but when loaded as text-only it may use model.layers prefix. + We handle both via the passthrough linear attention mappings. 
+ """ + # Determine prefix: VLM uses model.language_model, text-only uses model + hf_config = self.hf_pretrained.config + text_config = _get_text_config(hf_config) + # Check if it's a VLM config (has text_config attribute) + is_vlm = hasattr(hf_config, "text_config") + pfx = "model.language_model" if is_vlm else "model" + + param_mappings = { + # Embeddings and output + f"embedding.word_embeddings.weight": f"{pfx}.embed_tokens.weight", + f"output_layer.weight": "lm_head.weight", + f"decoder.final_layernorm.weight": f"{pfx}.norm.weight", + # Attention: input layernorm (TE fused) + f"decoder.layers.*.self_attention.linear_qkv.layer_norm_weight": f"{pfx}.layers.*.input_layernorm.weight", + # Attention: separate input layernorm + f"decoder.layers.*.input_layernorm.weight": f"{pfx}.layers.*.input_layernorm.weight", + # Attention output projection + f"decoder.layers.*.self_attention.linear_proj.weight": f"{pfx}.layers.*.self_attn.o_proj.weight", + # QK norms + f"decoder.layers.*.self_attention.q_layernorm.weight": f"{pfx}.layers.*.self_attn.q_norm.weight", + f"decoder.layers.*.self_attention.k_layernorm.weight": f"{pfx}.layers.*.self_attn.k_norm.weight", + # Post-attention layernorm + f"decoder.layers.*.mlp.linear_fc1.layer_norm_weight": f"{pfx}.layers.*.post_attention_layernorm.weight", + f"decoder.layers.*.pre_mlp_layernorm.weight": f"{pfx}.layers.*.post_attention_layernorm.weight", + # Dense MLP output + f"decoder.layers.*.mlp.linear_fc2.weight": f"{pfx}.layers.*.mlp.down_proj.weight", + } + + # Linear attention weights (passthrough — these live on the custom HF attention module) + linear_attn_weights = [ + "input_layernorm.weight", + "linear_attn.A_log", + "linear_attn.conv1d.weight", + "linear_attn.dt_bias", + "linear_attn.in_proj_a.weight", + "linear_attn.in_proj_b.weight", + "linear_attn.in_proj_qkv.weight", + "linear_attn.in_proj_z.weight", + "linear_attn.norm.weight", + "linear_attn.out_proj.weight", + ] + for w in linear_attn_weights: + param_mappings[f"decoder.layers.*.self_attention.{w}"] = f"{pfx}.layers.*.{w}" + + # Gated attention weights (for full_attention layers that use the spec path) + gated_attn_weights = [ + "self_attn.k_norm.weight", + "self_attn.k_proj.weight", + "self_attn.o_proj.weight", + "self_attn.q_norm.weight", + "self_attn.q_proj.weight", + "self_attn.v_proj.weight", + ] + for w in gated_attn_weights: + param_mappings[f"decoder.layers.*.self_attention.{w}"] = f"{pfx}.layers.*.{w}" + + mapping_list = [] + for megatron_param, hf_param in param_mappings.items(): + mapping_list.append(AutoMapping(megatron_param, hf_param)) + + # QKV mapping (for full attention layers) + mapping_list.append( + QKVMapping( + f"decoder.layers.*.self_attention.linear_qkv.weight", + [ + f"{pfx}.layers.*.self_attn.q_proj.weight", + f"{pfx}.layers.*.self_attn.k_proj.weight", + f"{pfx}.layers.*.self_attn.v_proj.weight", + ], + num_heads=text_config.num_attention_heads, + num_query_groups=text_config.num_key_value_heads, + head_dim=getattr(text_config, "head_dim", 256), + hidden_size=text_config.hidden_size, + ) + ) + + # Gated MLP (gate + up fused) + mapping_list.append( + GatedMLPMapping( + f"decoder.layers.*.mlp.linear_fc1.weight", + [ + f"{pfx}.layers.*.mlp.gate_proj.weight", + f"{pfx}.layers.*.mlp.up_proj.weight", + ], + ) + ) + + return MegatronMappingRegistry(mapping_list) diff --git a/slime/slime_plugins/models/qwen3_5.py b/slime/slime_plugins/models/qwen3_5.py new file mode 100644 index 0000000000..4ca36d9e6e --- /dev/null +++ b/slime/slime_plugins/models/qwen3_5.py @@ -0,0 +1,211 @@ 
+import copy + +import torch +import torch.nn as nn +import torch.nn.functional as F +from megatron.core.models.gpt.gpt_layer_specs import get_gpt_decoder_block_spec +from megatron.core.transformer.spec_utils import ModuleSpec +from megatron.core.transformer.transformer_block import get_num_layers_to_build +from megatron.core.transformer.transformer_layer import get_transformer_layer_offset +from transformers import AutoConfig + +Qwen3_5Attention = None +Qwen3_5RMSNorm = None +FusedRMSNormGated = None +ShortConvolution = None +chunk_gated_delta_rule = None + +try: + from fla.modules import FusedRMSNormGated, ShortConvolution + from fla.ops.gated_delta_rule import chunk_gated_delta_rule + from transformers.models.qwen3_5.modeling_qwen3_5 import Qwen3_5Attention, Qwen3_5RMSNorm +except ImportError: + pass + +from .hf_attention import HuggingfaceAttention + + +def _unwrap_text_config(hf_config): + return getattr(hf_config, "text_config", hf_config) + + +class Qwen35GatedDeltaNet(nn.Module): + """Varlen-compatible gated delta net for Qwen3.5 linear-attention layers.""" + + def __init__(self, config, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.num_v_heads = config.linear_num_value_heads + self.num_k_heads = config.linear_num_key_heads + self.head_k_dim = config.linear_key_head_dim + self.head_v_dim = config.linear_value_head_dim + self.key_dim = self.head_k_dim * self.num_k_heads + self.value_dim = self.head_v_dim * self.num_v_heads + + self.conv_kernel_size = config.linear_conv_kernel_dim + self.layer_idx = layer_idx + self.layer_norm_epsilon = config.rms_norm_eps + + self.conv_dim = self.key_dim * 2 + self.value_dim + self.conv1d = ShortConvolution( + hidden_size=self.conv_dim, + bias=False, + kernel_size=self.conv_kernel_size, + ) + + self.in_proj_qkv = nn.Linear(self.hidden_size, self.conv_dim, bias=False) + self.in_proj_z = nn.Linear(self.hidden_size, self.value_dim, bias=False) + self.in_proj_b = nn.Linear(self.hidden_size, self.num_v_heads, bias=False) + self.in_proj_a = nn.Linear(self.hidden_size, self.num_v_heads, bias=False) + + self.dt_bias = nn.Parameter(torch.ones(self.num_v_heads)) + A = torch.empty(self.num_v_heads).uniform_(0, 16) + self.A_log = nn.Parameter(torch.log(A)) + + self.norm = FusedRMSNormGated( + self.head_v_dim, + eps=self.layer_norm_epsilon, + device=torch.cuda.current_device(), + dtype=getattr(config, "dtype", None) or torch.get_current_dtype(), + ) + self.out_proj = nn.Linear(self.value_dim, self.hidden_size, bias=False) + + def fix_query_key_value_ordering(self, mixed_qkv): + new_tensor_shape_qkv = mixed_qkv.size()[:-1] + ( + self.num_k_heads, + 2 * self.head_k_dim + self.head_v_dim * self.num_v_heads // self.num_k_heads, + ) + mixed_qkv = mixed_qkv.view(*new_tensor_shape_qkv) + split_arg_list_qkv = [ + self.head_k_dim, + self.head_k_dim, + self.num_v_heads // self.num_k_heads * self.head_v_dim, + ] + query, key, value = torch.split(mixed_qkv, split_arg_list_qkv, dim=3) + value = value.reshape(value.size(0), value.size(1), -1, self.head_v_dim) + return query, key, value + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor = None, + ): + projected_qkv = self.in_proj_qkv(hidden_states) + projected_z = self.in_proj_z(hidden_states) + projected_b = self.in_proj_b(hidden_states) + projected_a = self.in_proj_a(hidden_states) + + query, key, value = self.fix_query_key_value_ordering(projected_qkv) + query, key, value = (x.reshape(x.shape[0], x.shape[1], -1) for x in (query, key, value)) + + 
mixed_qkv = torch.cat((query, key, value), dim=-1) + mixed_qkv, _ = self.conv1d( + x=mixed_qkv, + cu_seqlens=cu_seqlens, + ) + + query, key, value = torch.split( + mixed_qkv, + [ + self.key_dim, + self.key_dim, + self.value_dim, + ], + dim=-1, + ) + query = query.reshape(query.shape[0], query.shape[1], -1, self.head_k_dim) + key = key.reshape(key.shape[0], key.shape[1], -1, self.head_k_dim) + value = value.reshape(value.shape[0], value.shape[1], -1, self.head_v_dim) + + gate = projected_z.reshape(projected_z.shape[0], projected_z.shape[1], -1, self.head_v_dim) + beta = projected_b.sigmoid() + alpha = projected_a + g = -self.A_log.float().exp() * F.softplus(alpha.float() + self.dt_bias) + + if self.num_v_heads // self.num_k_heads > 1: + query = query.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2) + key = key.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2) + + core_attn_out, _ = chunk_gated_delta_rule( + query, + key, + value, + g=g, + beta=beta, + initial_state=None, + output_final_state=False, + use_qk_l2norm_in_kernel=True, + ) + + gate_shape = gate.shape + core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1]) + gate = gate.reshape(-1, gate.shape[-1]) + core_attn_out = self.norm(core_attn_out, gate) + core_attn_out = core_attn_out.reshape(gate_shape) + core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], core_attn_out.shape[1], -1) + + return self.out_proj(core_attn_out) + + +class Attention(HuggingfaceAttention): + def __init__( + self, + args, + config, + layer_number: int, + cp_comm_type: str = "p2p", + pg_collection=None, + ): + super().__init__( + args, + config, + layer_number, + cp_comm_type, + pg_collection, + ) + if ( + Qwen3_5Attention is None + or Qwen3_5RMSNorm is None + or FusedRMSNormGated is None + or ShortConvolution is None + or chunk_gated_delta_rule is None + ): + raise ImportError("Please install a transformers build with Qwen3.5 support.") + + self.hf_config = _unwrap_text_config(self.hf_config) + self.linear_attn = Qwen35GatedDeltaNet(self.hf_config, self.hf_layer_idx) + self.input_layernorm = Qwen3_5RMSNorm(self.hf_config.hidden_size, eps=self.hf_config.rms_norm_eps) + + def hf_forward(self, hidden_states, packed_seq_params): + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.linear_attn( + hidden_states=hidden_states, + cu_seqlens=packed_seq_params.cu_seqlens_q, + ) + return hidden_states + + +def get_qwen3_5_spec(args, config, vp_stage): + kwargs = { + "use_transformer_engine": True, + } + if vp_stage is not None: + kwargs["vp_stage"] = vp_stage + transformer_layer_spec = get_gpt_decoder_block_spec(config, **kwargs) + + assert config.pipeline_model_parallel_layout is None, "not support this at the moment" + + num_layers_to_build = get_num_layers_to_build(config, vp_stage=vp_stage) + offset = get_transformer_layer_offset(config, vp_stage=vp_stage) + + hf_config = _unwrap_text_config(AutoConfig.from_pretrained(args.hf_checkpoint, trust_remote_code=True)) + + for layer_id in range(num_layers_to_build): + if hf_config.layer_types[layer_id + offset] == "linear_attention": + layer_specs = copy.deepcopy(transformer_layer_spec.layer_specs[layer_id]) + layer_specs.submodules.self_attention = ModuleSpec( + module=Attention, + params={"args": args}, + ) + transformer_layer_spec.layer_specs[layer_id] = layer_specs + return transformer_layer_spec diff --git a/slime/slime_plugins/sglang_models/__init__.py b/slime/slime_plugins/sglang_models/__init__.py new file mode 100644 index 
0000000000..e69de29bb2 diff --git a/slime/slime_plugins/sglang_models/qwen3_5_causal.py b/slime/slime_plugins/sglang_models/qwen3_5_causal.py new file mode 100644 index 0000000000..ca3ccac23c --- /dev/null +++ b/slime/slime_plugins/sglang_models/qwen3_5_causal.py @@ -0,0 +1,9 @@ +import os + +if os.environ.get("SLIME_ENABLE_QWEN35_SGLANG_PATCH") == "1": + from slime.backends.sglang_utils.qwen3_5 import patch_sglang_qwen35 + patch_sglang_qwen35() + +from sglang.srt.models.qwen3_5 import Qwen3_5ForCausalLM + +EntryClass = [Qwen3_5ForCausalLM] From 80909a7f8ea9373c22be2679709dac4c9e34d8b3 Mon Sep 17 00:00:00 2001 From: miaoziyang Date: Mon, 23 Mar 2026 10:58:07 +0800 Subject: [PATCH 2/2] fix: patch Qwen3Config rope_theta compatibility for transformers>=5.0 and fix Qwen3.5-4B support --- .../run_qwen35_4b_openclaw_combine.sh | 4 +- .../run_qwen35_4b_openclaw_combine_lora.sh | 2 +- .../run_qwen3_4b_openclaw_combine.sh | 13 +- .../run_qwen3_4b_openclaw_combine_lora.sh | 2 +- slime/slime/utils/arguments.py | 17 +- .../slime_plugins/megatron_bridge/__init__.py | 16 ++ .../slime_plugins/megatron_bridge/qwen3_5.py | 51 +++--- slime/slime_plugins/models/qwen3_5.py | 159 ++++++++++-------- 8 files changed, 150 insertions(+), 114 deletions(-) diff --git a/openclaw-combine/run_qwen35_4b_openclaw_combine.sh b/openclaw-combine/run_qwen35_4b_openclaw_combine.sh index 0faef59e36..d051b6bbf5 100644 --- a/openclaw-combine/run_qwen35_4b_openclaw_combine.sh +++ b/openclaw-combine/run_qwen35_4b_openclaw_combine.sh @@ -41,7 +41,7 @@ SLIME_ROOT="$(cd -- "${SCRIPT_DIR}/../slime" &>/dev/null && pwd)" REPO_ROOT="$(cd -- "${SCRIPT_DIR}/.." &>/dev/null && pwd)" source "${SLIME_ROOT}/scripts/models/qwen3.5-4B.sh" -HF_CKPT=${HF_CKPT:-${REPO_ROOT}/models/Qwen3.5-4B} +HF_CKPT=${HF_CKPT:-/path/to/models/Qwen/Qwen3.5-4B} REF_LOAD=${REF_LOAD:-${HF_CKPT}} SAVE_CKPT=${SAVE_CKPT:-${REPO_ROOT}/ckpt/qwen35-4b-openclaw-combine} PRM_MODEL_PATH=${PRM_MODEL_PATH:-${HF_CKPT}} @@ -69,7 +69,7 @@ CKPT_ARGS=( --ref-load "${REF_LOAD}" --save "${SAVE_CKPT}" --save-interval 100 - --rotary-base 5000000 + --rotary-base 10000000 ) ROLLOUT_ARGS=( diff --git a/openclaw-combine/run_qwen35_4b_openclaw_combine_lora.sh b/openclaw-combine/run_qwen35_4b_openclaw_combine_lora.sh index 9e3cfa4e51..17e221088b 100644 --- a/openclaw-combine/run_qwen35_4b_openclaw_combine_lora.sh +++ b/openclaw-combine/run_qwen35_4b_openclaw_combine_lora.sh @@ -41,7 +41,7 @@ SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" REPO_ROOT="$(cd -- "${SCRIPT_DIR}/.." &>/dev/null && pwd)" SLIME_ROOT="$(cd -- "${SCRIPT_DIR}/../slime" &>/dev/null && pwd)" -HF_CKPT=${HF_CKPT:-${REPO_ROOT}/models/Qwen3.5-4B} +HF_CKPT=${HF_CKPT:-/path/to/models/Qwen/Qwen3.5-4B} REF_LOAD=${REF_LOAD:-${HF_CKPT}} SAVE_CKPT=${SAVE_CKPT:-${REPO_ROOT}/ckpt/qwen35-4b-openclaw-combine-lora} PRM_MODEL_PATH=${PRM_MODEL_PATH:-${HF_CKPT}} diff --git a/openclaw-combine/run_qwen3_4b_openclaw_combine.sh b/openclaw-combine/run_qwen3_4b_openclaw_combine.sh index bdb71265a8..7b1ae3ec20 100644 --- a/openclaw-combine/run_qwen3_4b_openclaw_combine.sh +++ b/openclaw-combine/run_qwen3_4b_openclaw_combine.sh @@ -32,12 +32,13 @@ export RAY_num_heartbeats_timeout=60 SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" SLIME_ROOT="$(cd -- "${SCRIPT_DIR}/../slime" &>/dev/null && pwd)" +REPO_ROOT="$(cd -- "${SCRIPT_DIR}/.." 
&>/dev/null && pwd)" source "${SLIME_ROOT}/scripts/models/qwen3-4B.sh" -HF_CKPT=${HF_CKPT:-/data_storage/wyj/systems/huggingface/hub/Qwen3-4B-Thinking-2507} +HF_CKPT=${HF_CKPT:-/path/to/models/Qwen/Qwen3-4B} REF_LOAD=${REF_LOAD:-${HF_CKPT}} -SAVE_CKPT=${SAVE_CKPT:-/data_storage/wyj/OpenClaw-RL/ckpt/qwen3-4b-openclaw-combine} -PRM_MODEL_PATH=${PRM_MODEL_PATH:-/data_storage/wyj/systems/huggingface/hub/Qwen3-4B-Thinking-2507} +SAVE_CKPT=${SAVE_CKPT:-${REPO_ROOT}/ckpt/qwen3-4b-openclaw-combine} +PRM_MODEL_PATH=${PRM_MODEL_PATH:-${HF_CKPT}} export SGLANG_API_KEY="${SGLANG_API_KEY}" export SERVED_MODEL_NAME="qwen3-4b" @@ -61,7 +62,7 @@ CKPT_ARGS=( --ref-load "${REF_LOAD}" --save "${SAVE_CKPT}" --save-interval 100 - --rotary-base 5000000 + --rotary-base 1000000 ) ROLLOUT_ARGS=( @@ -176,7 +177,7 @@ ray start --head --node-ip-address "${MASTER_ADDR}" --num-gpus "${NUM_GPUS}" --d RUNTIME_ENV_JSON="{ \"env_vars\": { - \"PYTHONPATH\": \"/data_storage/wyj/OpenClaw-RL/Megatron-LM/:${SCRIPT_DIR}:${SCRIPT_DIR}/../openclaw-opd:${SLIME_ROOT}\", + \"PYTHONPATH\": \"${REPO_ROOT}/Megatron-LM/:${SCRIPT_DIR}:${SCRIPT_DIR}/../openclaw-opd:${SLIME_ROOT}\", \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\", \"OPENCLAW_EVAL_MODE\": \"${OPENCLAW_EVAL_MODE}\", \"OPENCLAW_COMBINE_W_RL\": \"${OPENCLAW_COMBINE_W_RL}\", @@ -186,7 +187,7 @@ RUNTIME_ENV_JSON="{ ray job submit --address="http://127.0.0.1:8265" \ --runtime-env-json="${RUNTIME_ENV_JSON}" \ - -- python3 train_async.py \ + -- python3 "${SLIME_ROOT}/train_async.py" \ --actor-num-nodes 1 \ --actor-num-gpus-per-node "${ACTOR_GPUS}" \ --rollout-num-gpus "${ROLLOUT_GPUS}" \ diff --git a/openclaw-combine/run_qwen3_4b_openclaw_combine_lora.sh b/openclaw-combine/run_qwen3_4b_openclaw_combine_lora.sh index e2202c3945..b5eb351ffe 100755 --- a/openclaw-combine/run_qwen3_4b_openclaw_combine_lora.sh +++ b/openclaw-combine/run_qwen3_4b_openclaw_combine_lora.sh @@ -35,7 +35,7 @@ SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" SLIME_ROOT="$(cd -- "${SCRIPT_DIR}/../slime" &>/dev/null && pwd)" REPO_ROOT="$(cd -- "${SCRIPT_DIR}/.." 
&>/dev/null && pwd)" -HF_CKPT=${HF_CKPT:-${REPO_ROOT}/models/Qwen3-4B} +HF_CKPT=${HF_CKPT:-/path/to/models/Qwen/Qwen3-4B} REF_LOAD=${REF_LOAD:-${HF_CKPT}} SAVE_CKPT=${SAVE_CKPT:-${REPO_ROOT}/ckpt/qwen3-4b-openclaw-combine-lora} PRM_MODEL_PATH=${PRM_MODEL_PATH:-${HF_CKPT}} diff --git a/slime/slime/utils/arguments.py b/slime/slime/utils/arguments.py index 9428e98f50..4e0a754a69 100644 --- a/slime/slime/utils/arguments.py +++ b/slime/slime/utils/arguments.py @@ -1904,10 +1904,21 @@ def equal(x, y): ("rms_norm_eps", "norm_epsilon", equal), ("rope_theta", "rotary_base", equal), ]: - if hasattr(hf_config, hf_config_name): - if not compare_fn(getattr(hf_config, hf_config_name), getattr(args, megatron_config_name)): + # Handle nested rope_parameters for models like Qwen3.5 + hf_value = None + if hf_config_name == "rope_theta" and hasattr(hf_config, "rope_parameters"): + rope_params = getattr(hf_config, "rope_parameters") + if isinstance(rope_params, dict) and "rope_theta" in rope_params: + hf_value = rope_params["rope_theta"] + elif hasattr(rope_params, "rope_theta"): + hf_value = getattr(rope_params, "rope_theta") + elif hasattr(hf_config, hf_config_name): + hf_value = getattr(hf_config, hf_config_name) + + if hf_value is not None: + if not compare_fn(hf_value, getattr(args, megatron_config_name)): errors.append( - f"{hf_config_name} in hf config {getattr(hf_config, hf_config_name)} is not equal to " + f"{hf_config_name} in hf config {hf_value} is not equal to " f"{megatron_config_name} {getattr(args, megatron_config_name)}, please check the config." ) diff --git a/slime/slime_plugins/megatron_bridge/__init__.py b/slime/slime_plugins/megatron_bridge/__init__.py index 32c349576b..3aec20c9fe 100644 --- a/slime/slime_plugins/megatron_bridge/__init__.py +++ b/slime/slime_plugins/megatron_bridge/__init__.py @@ -1 +1,17 @@ +# rope_theta_compat: monkey-patch Qwen3Config for transformers>=5.0 +try: + from transformers.models.qwen3.configuration_qwen3 import Qwen3Config as _Qwen3Config + + _qwen3_orig_init = _Qwen3Config.__init__ + + def _qwen3_patched_init(self, *args, **kwargs): + _rope_theta = kwargs.pop("rope_theta", None) + _qwen3_orig_init(self, *args, **kwargs) + if not hasattr(self, "rope_theta") or self.rope_theta is None: + self.rope_theta = _rope_theta if _rope_theta is not None else 1000000 + + _Qwen3Config.__init__ = _qwen3_patched_init +except Exception: + pass + import slime_plugins.megatron_bridge.qwen3_5 # noqa: F401 # register Qwen3.5 bridge diff --git a/slime/slime_plugins/megatron_bridge/qwen3_5.py b/slime/slime_plugins/megatron_bridge/qwen3_5.py index 237f3da626..2018a17cc0 100644 --- a/slime/slime_plugins/megatron_bridge/qwen3_5.py +++ b/slime/slime_plugins/megatron_bridge/qwen3_5.py @@ -35,6 +35,15 @@ def _get_text_config(hf_config): class MegatronQwen35Bridge(MegatronModelBridge): """Bridge between HuggingFace Qwen3.5 and Megatron GPTModel.""" + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.hf_pretrained = None + + def load_weights_hf_to_megatron(self, hf_pretrained, model): + """Store hf_pretrained before calling parent's load method.""" + self.hf_pretrained = hf_pretrained + return super().load_weights_hf_to_megatron(hf_pretrained, model) + def provider_bridge(self, hf_pretrained): """Create a GPT ModelProvider from Qwen3.5 HF config.""" from megatron.bridge.models.qwen.qwen_provider import Qwen3ModelProvider @@ -84,6 +93,11 @@ def mapping_registry(self) -> MegatronMappingRegistry: We handle both via the passthrough linear attention mappings. 
""" # Determine prefix: VLM uses model.language_model, text-only uses model + if self.hf_pretrained is None: + raise RuntimeError( + "hf_pretrained is not set. Ensure load_weights_hf_to_megatron() " + "is called before mapping_registry()." + ) hf_config = self.hf_pretrained.config text_config = _get_text_config(hf_config) # Check if it's a VLM config (has text_config attribute) @@ -127,47 +141,28 @@ def mapping_registry(self) -> MegatronMappingRegistry: for w in linear_attn_weights: param_mappings[f"decoder.layers.*.self_attention.{w}"] = f"{pfx}.layers.*.{w}" - # Gated attention weights (for full_attention layers that use the spec path) - gated_attn_weights = [ - "self_attn.k_norm.weight", - "self_attn.k_proj.weight", - "self_attn.o_proj.weight", - "self_attn.q_norm.weight", - "self_attn.q_proj.weight", - "self_attn.v_proj.weight", - ] - for w in gated_attn_weights: - param_mappings[f"decoder.layers.*.self_attention.{w}"] = f"{pfx}.layers.*.{w}" mapping_list = [] for megatron_param, hf_param in param_mappings.items(): - mapping_list.append(AutoMapping(megatron_param, hf_param)) + mapping_list.append(AutoMapping(megatron_param=megatron_param, hf_param=hf_param)) # QKV mapping (for full attention layers) mapping_list.append( QKVMapping( - f"decoder.layers.*.self_attention.linear_qkv.weight", - [ - f"{pfx}.layers.*.self_attn.q_proj.weight", - f"{pfx}.layers.*.self_attn.k_proj.weight", - f"{pfx}.layers.*.self_attn.v_proj.weight", - ], - num_heads=text_config.num_attention_heads, - num_query_groups=text_config.num_key_value_heads, - head_dim=getattr(text_config, "head_dim", 256), - hidden_size=text_config.hidden_size, + megatron_param=f"decoder.layers.*.self_attention.linear_qkv.weight", + q=f"{pfx}.layers.*.self_attn.q_proj.weight", + k=f"{pfx}.layers.*.self_attn.k_proj.weight", + v=f"{pfx}.layers.*.self_attn.v_proj.weight", ) ) # Gated MLP (gate + up fused) mapping_list.append( GatedMLPMapping( - f"decoder.layers.*.mlp.linear_fc1.weight", - [ - f"{pfx}.layers.*.mlp.gate_proj.weight", - f"{pfx}.layers.*.mlp.up_proj.weight", - ], + megatron_param=f"decoder.layers.*.mlp.linear_fc1.weight", + gate=f"{pfx}.layers.*.mlp.gate_proj.weight", + up=f"{pfx}.layers.*.mlp.up_proj.weight", ) ) - return MegatronMappingRegistry(mapping_list) + return MegatronMappingRegistry(*mapping_list) diff --git a/slime/slime_plugins/models/qwen3_5.py b/slime/slime_plugins/models/qwen3_5.py index 4ca36d9e6e..18827dc87e 100644 --- a/slime/slime_plugins/models/qwen3_5.py +++ b/slime/slime_plugins/models/qwen3_5.py @@ -7,30 +7,31 @@ from megatron.core.transformer.spec_utils import ModuleSpec from megatron.core.transformer.transformer_block import get_num_layers_to_build from megatron.core.transformer.transformer_layer import get_transformer_layer_offset -from transformers import AutoConfig - -Qwen3_5Attention = None -Qwen3_5RMSNorm = None -FusedRMSNormGated = None -ShortConvolution = None -chunk_gated_delta_rule = None +from transformers.activations import ACT2FN try: from fla.modules import FusedRMSNormGated, ShortConvolution from fla.ops.gated_delta_rule import chunk_gated_delta_rule - from transformers.models.qwen3_5.modeling_qwen3_5 import Qwen3_5Attention, Qwen3_5RMSNorm except ImportError: pass -from .hf_attention import HuggingfaceAttention +from .hf_attention import HuggingfaceAttention, _load_hf_config -def _unwrap_text_config(hf_config): - return getattr(hf_config, "text_config", hf_config) +def _get_text_config(hf_config): + """Extract text config from a VLM config if needed.""" + if hasattr(hf_config, 
"text_config"): + return hf_config.text_config + return hf_config -class Qwen35GatedDeltaNet(nn.Module): - """Varlen-compatible gated delta net for Qwen3.5 linear-attention layers.""" +# Adapted from Qwen3NextGatedDeltaNet but with separate in_proj_qkv and in_proj_z +class Qwen3_5GatedDeltaNet(nn.Module): + """ + Qwen3.5 GatedDeltaNet with varlen support. + Unlike Qwen3Next which uses a combined in_proj_qkvz, Qwen3.5 uses + separate in_proj_qkv (for Q,K,V) and in_proj_z (for Z). + """ def __init__(self, config, layer_idx: int): super().__init__() @@ -44,8 +45,11 @@ def __init__(self, config, layer_idx: int): self.conv_kernel_size = config.linear_conv_kernel_dim self.layer_idx = layer_idx + self.activation = config.hidden_act + self.act = ACT2FN[config.hidden_act] self.layer_norm_epsilon = config.rms_norm_eps + # QKV self.conv_dim = self.key_dim * 2 + self.value_dim self.conv1d = ShortConvolution( hidden_size=self.conv_dim, @@ -53,80 +57,68 @@ def __init__(self, config, layer_idx: int): kernel_size=self.conv_kernel_size, ) - self.in_proj_qkv = nn.Linear(self.hidden_size, self.conv_dim, bias=False) - self.in_proj_z = nn.Linear(self.hidden_size, self.value_dim, bias=False) + # Separate projections for QKV and Z (unlike Qwen3Next which combines QKVZ) + projection_size_qkv = self.key_dim * 2 + self.value_dim + projection_size_z = self.value_dim + self.in_proj_qkv = nn.Linear(self.hidden_size, projection_size_qkv, bias=False) + self.in_proj_z = nn.Linear(self.hidden_size, projection_size_z, bias=False) self.in_proj_b = nn.Linear(self.hidden_size, self.num_v_heads, bias=False) self.in_proj_a = nn.Linear(self.hidden_size, self.num_v_heads, bias=False) + # time step projection self.dt_bias = nn.Parameter(torch.ones(self.num_v_heads)) + A = torch.empty(self.num_v_heads).uniform_(0, 16) self.A_log = nn.Parameter(torch.log(A)) self.norm = FusedRMSNormGated( self.head_v_dim, eps=self.layer_norm_epsilon, + activation=self.activation, device=torch.cuda.current_device(), - dtype=getattr(config, "dtype", None) or torch.get_current_dtype(), + dtype=config.dtype if config.dtype is not None else torch.get_current_dtype(), ) - self.out_proj = nn.Linear(self.value_dim, self.hidden_size, bias=False) - def fix_query_key_value_ordering(self, mixed_qkv): - new_tensor_shape_qkv = mixed_qkv.size()[:-1] + ( - self.num_k_heads, - 2 * self.head_k_dim + self.head_v_dim * self.num_v_heads // self.num_k_heads, - ) - mixed_qkv = mixed_qkv.view(*new_tensor_shape_qkv) - split_arg_list_qkv = [ - self.head_k_dim, - self.head_k_dim, - self.num_v_heads // self.num_k_heads * self.head_v_dim, - ] - query, key, value = torch.split(mixed_qkv, split_arg_list_qkv, dim=3) - value = value.reshape(value.size(0), value.size(1), -1, self.head_v_dim) - return query, key, value + self.out_proj = nn.Linear(self.value_dim, self.hidden_size, bias=False) def forward( self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor = None, ): - projected_qkv = self.in_proj_qkv(hidden_states) - projected_z = self.in_proj_z(hidden_states) - projected_b = self.in_proj_b(hidden_states) - projected_a = self.in_proj_a(hidden_states) + batch_size, seq_len, _ = hidden_states.shape - query, key, value = self.fix_query_key_value_ordering(projected_qkv) - query, key, value = (x.reshape(x.shape[0], x.shape[1], -1) for x in (query, key, value)) + # Projections (flat layout: [Q_all, K_all, V_all]) + mixed_qkv = self.in_proj_qkv(hidden_states) + z = self.in_proj_z(hidden_states) + z = z.reshape(batch_size, seq_len, -1, self.head_v_dim) + b = 
self.in_proj_b(hidden_states) + a = self.in_proj_a(hidden_states) - mixed_qkv = torch.cat((query, key, value), dim=-1) + # Convolution on the flat QKV mixed_qkv, _ = self.conv1d( x=mixed_qkv, cu_seqlens=cu_seqlens, ) + # Split into Q, K, V (flat split, matching HF layout) query, key, value = torch.split( mixed_qkv, - [ - self.key_dim, - self.key_dim, - self.value_dim, - ], + [self.key_dim, self.key_dim, self.value_dim], dim=-1, ) - query = query.reshape(query.shape[0], query.shape[1], -1, self.head_k_dim) - key = key.reshape(key.shape[0], key.shape[1], -1, self.head_k_dim) - value = value.reshape(value.shape[0], value.shape[1], -1, self.head_v_dim) - - gate = projected_z.reshape(projected_z.shape[0], projected_z.shape[1], -1, self.head_v_dim) - beta = projected_b.sigmoid() - alpha = projected_a - g = -self.A_log.float().exp() * F.softplus(alpha.float() + self.dt_bias) + query = query.reshape(batch_size, seq_len, -1, self.head_k_dim) + key = key.reshape(batch_size, seq_len, -1, self.head_k_dim) + value = value.reshape(batch_size, seq_len, -1, self.head_v_dim) + beta = b.sigmoid() + # If the model is loaded in fp16, without the .float() here, A might be -inf + g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias) if self.num_v_heads // self.num_k_heads > 1: query = query.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2) key = key.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2) - core_attn_out, _ = chunk_gated_delta_rule( + core_attn_out, last_recurrent_state = chunk_gated_delta_rule( query, key, value, @@ -135,16 +127,19 @@ def forward( initial_state=None, output_final_state=False, use_qk_l2norm_in_kernel=True, + cu_seqlens=cu_seqlens, ) - gate_shape = gate.shape + z_shape_og = z.shape + # reshape input data into 2D tensor core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1]) - gate = gate.reshape(-1, gate.shape[-1]) - core_attn_out = self.norm(core_attn_out, gate) - core_attn_out = core_attn_out.reshape(gate_shape) - core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], core_attn_out.shape[1], -1) + z = z.reshape(-1, z.shape[-1]) + core_attn_out = self.norm(core_attn_out, z) + core_attn_out = core_attn_out.reshape(z_shape_og) + core_attn_out = core_attn_out.reshape(batch_size, seq_len, -1) - return self.out_proj(core_attn_out) + output = self.out_proj(core_attn_out) + return output class Attention(HuggingfaceAttention): @@ -163,18 +158,21 @@ def __init__( cp_comm_type, pg_collection, ) - if ( - Qwen3_5Attention is None - or Qwen3_5RMSNorm is None - or FusedRMSNormGated is None - or ShortConvolution is None - or chunk_gated_delta_rule is None - ): - raise ImportError("Please install a transformers build with Qwen3.5 support.") - - self.hf_config = _unwrap_text_config(self.hf_config) - self.linear_attn = Qwen35GatedDeltaNet(self.hf_config, self.hf_layer_idx) - self.input_layernorm = Qwen3_5RMSNorm(self.hf_config.hidden_size, eps=self.hf_config.rms_norm_eps) + # Qwen3.5 is a VLM model with nested text_config + self.hf_config = _get_text_config(self.hf_config) + self.hf_config._attn_implementation = "flash_attention_2" + + self.linear_attn = Qwen3_5GatedDeltaNet(self.hf_config, self.hf_layer_idx) + + # Use a simple RMSNorm + try: + from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextRMSNorm + + self.input_layernorm = Qwen3NextRMSNorm(self.hf_config.hidden_size, eps=self.hf_config.rms_norm_eps) + except ImportError: + from torch.nn import RMSNorm + + self.input_layernorm = RMSNorm(self.hf_config.hidden_size, 
eps=self.hf_config.rms_norm_eps) def hf_forward(self, hidden_states, packed_seq_params): hidden_states = self.input_layernorm(hidden_states) @@ -186,6 +184,11 @@ def hf_forward(self, hidden_states, packed_seq_params): def get_qwen3_5_spec(args, config, vp_stage): + # always use the moe path for MoE models + if not args.num_experts: + config.moe_layer_freq = [0] * config.num_layers + + # Define the decoder block spec kwargs = { "use_transformer_engine": True, } @@ -195,17 +198,27 @@ def get_qwen3_5_spec(args, config, vp_stage): assert config.pipeline_model_parallel_layout is None, "not support this at the moment" + # Slice the layer specs to only include the layers that are built in this pipeline stage. num_layers_to_build = get_num_layers_to_build(config, vp_stage=vp_stage) offset = get_transformer_layer_offset(config, vp_stage=vp_stage) - hf_config = _unwrap_text_config(AutoConfig.from_pretrained(args.hf_checkpoint, trust_remote_code=True)) + hf_config = _load_hf_config(args.hf_checkpoint) + text_config = _get_text_config(hf_config) + + # Compute layer_types if the config class doesn't expose it + if not hasattr(text_config, "layer_types"): + interval = getattr(text_config, "full_attention_interval", 4) + n = text_config.num_hidden_layers + text_config.layer_types = [ + "full_attention" if (i + 1) % interval == 0 else "linear_attention" for i in range(n) + ] for layer_id in range(num_layers_to_build): - if hf_config.layer_types[layer_id + offset] == "linear_attention": + if text_config.layer_types[layer_id + offset] == "linear_attention": layer_specs = copy.deepcopy(transformer_layer_spec.layer_specs[layer_id]) layer_specs.submodules.self_attention = ModuleSpec( module=Attention, params={"args": args}, ) transformer_layer_spec.layer_specs[layer_id] = layer_specs - return transformer_layer_spec + return transformer_layer_spec \ No newline at end of file
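
Usage sketch (illustrative, not part of the diffs above; the checkpoint path is a
placeholder): the text-only shadow preparation added in
slime/slime/backends/sglang_utils/qwen3_5.py can be exercised on its own before
launching SGLang. For a VLM-style checkpoint with a nested text_config it returns
a cached shadow directory; for anything else it returns the input path unchanged.

    from slime.backends.sglang_utils.qwen3_5 import (
        is_qwen35_model_path,
        maybe_prepare_qwen35_text_model,
    )

    ckpt = "/path/to/models/Qwen/Qwen3.5-4B"  # placeholder path
    if is_qwen35_model_path(ckpt):
        # The first call materializes the shadow dir: every file except
        # config.json is symlinked, and config.json is rewritten to the
        # text-only (Qwen3_5ForCausalLM / qwen3_5_text) variant.
        text_ckpt = maybe_prepare_qwen35_text_model(ckpt, language_only=True)
        print(text_ckpt)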
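For reference, the fused-QKV split in convert_qwen3_5_to_hf can be shape-checked
in isolation with the 4B geometry from slime/scripts/models/qwen3.5-4B.sh
(16 heads, 4 query groups, kv-channels 256, hidden 2560); the doubled q width
reflects the query+gate packing that the gated-attention conversion assumes:

    import torch

    ng, heads, hd, hidden = 4, 16, 256, 2560
    vg = heads // ng  # query heads per KV group
    # Megatron fused rows per group: 2*vg (query+gate), 1 (key), 1 (value)
    qkv = torch.randn(ng * (2 * vg + 2) * hd, hidden)
    p = qkv.view(ng, -1, hd, hidden)
    q, k, v = torch.split(p, [2 * vg, 1, 1], dim=1)
    # Same de-interleave as the converter: (ng, 2, vg, ...) -> (ng, vg, 2, ...)
    q = q.reshape(ng, 2, vg, hd, hidden).transpose(1, 2).reshape(-1, hidden)
    assert q.shape == (2 * heads * hd, hidden)  # 8192 x 2560, q_proj with gate
    assert k.reshape(-1, hidden).shape == (ng * hd, hidden)  # 1024 x 2560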
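Likewise, the layer_types fallback in get_qwen3_5_spec is easy to sanity-check
(assuming the default full_attention_interval of 4 and the 4B depth of 32 layers
from the model-args script):

    interval, n = 4, 32
    layer_types = [
        "full_attention" if (i + 1) % interval == 0 else "linear_attention"
        for i in range(n)
    ]
    # Full attention lands on layers 3, 7, ..., 31; all others stay linear.
    assert [i for i, t in enumerate(layer_types) if t == "full_attention"] == list(range(3, n, interval))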