feat: add DeepGEMM JIT kernel warmup integrated into C++ engine startup#967
Open
ydshi0 wants to merge 3 commits into alibaba:main from
Conversation
Collaborator
AI Code Review - PR #967 — Status: BLOCKING. Summary: P0/1 · P1/1 · P2/3 · P3/3.
Checklist violations: 12 fail / 84 total (General Principles, RTP-LLM, Python Static-First checklists).
Collaborator
AI Code Review - PR #967 — Status: LGTM (lgtm ready to ci). Summary: P0/0 · P1/0 · P2/3 · P3/2.
Checklist violations: 7 fail / 84 total (General Principles, RTP-LLM, Python Static-First checklists).
Collaborator
AI Code Review - PR #967 — Status: LGTM (lgtm ready to ci). Summary: P0/0 · P1/0 · P2/2 · P3/0.
Checklist violations: 4 fail / 56 total (General Principles, Python Static-First checklists).
- rtp_llm/models_py/kernels/cuda/deepgemm_warmup.py (new):
DeepGEMM JIT warmup implementation — pre-compiles fp8_gemm_nt (dense)
and m_grouped_fp8_gemm_nt_masked (MoE) kernels for representative M values
to avoid first-request nvcc latency. GroupedGemmWarmupWrapper encapsulates
MoE dummy tensor construction behind a forward() interface.
- rtp_llm/models_py/model_desc/module_base.py:
Add warmup_deep_gemm() method to GptModelBase, reads deepgemm_warmup
config and delegates to deepgemm_warmup.deep_gemm_warmup().
- rtp_llm/cpp/normal_engine/NormalEngine.cc:
Call py_model.warmup_deep_gemm() in NormalEngine constructor before
engine warmUp(), so JIT kernels are compiled before fake inference.
- rtp_llm/config/model_config.py:
Add deepgemm_warmup field to ModelConfig (1=enabled, 0=disabled,
full=all M values).
- rtp_llm/models_py/BUILD:
Add visibility for kernels target to allow subpackage test access.
- rtp_llm/models_py/kernels/cuda/test/BUILD + test_deepgemm_warmup.py (new):
Unit tests verifying warmup calls fp8 linear forward() and MoE
m_grouped_fp8_gemm_nt_masked, plus JIT cache file write/reuse.
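The warmup flow described above can be sketched roughly as follows. This is a hypothetical simplification, not the PR's actual implementation: the `gemm_fn` callable stands in for the real DeepGEMM entry points (`fp8_gemm_nt` / `m_grouped_fp8_gemm_nt_masked`) and the fp8 dummy-tensor construction, and `DEFAULT_WARMUP_MS` is an illustrative list of representative M values.

```python
import logging
from typing import Callable, Iterable, List

logger = logging.getLogger(__name__)

# Illustrative set of representative M values to pre-compile for;
# the real list depends on the model and batching configuration.
DEFAULT_WARMUP_MS: List[int] = [1, 16, 64, 256, 1024, 4096, 8192]


class GroupedGemmWarmupWrapper:
    """Hides dummy-tensor construction behind a forward() interface."""

    def __init__(self, gemm_fn: Callable[[int], None]) -> None:
        self._gemm_fn = gemm_fn      # stand-in for the real grouped fp8 GEMM
        self._warmed: set = set()    # M values already JIT-compiled

    def forward(self, m: int) -> None:
        # Compile the kernel for this M once; later calls hit the JIT cache.
        if m in self._warmed:
            return
        self._gemm_fn(m)             # real code builds dummy fp8 tensors here
        self._warmed.add(m)


def deep_gemm_warmup(wrapper: GroupedGemmWarmupWrapper,
                     ms: Iterable[int] = DEFAULT_WARMUP_MS) -> int:
    """Pre-compile kernels for representative M values; returns count visited."""
    count = 0
    for m in ms:
        wrapper.forward(m)
        count += 1
    return count
```

Running the warmup at engine startup moves the one-time nvcc compilation cost out of the first request; a second pass over the same M values is a no-op because every kernel is already in the cache.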
P0/P1 (Blocking): GIL acquisition — NormalEngine.cc:81-90
Problem: The NormalEngine constructor calls Python (py::hasattr, attr(...)(), e.what()) inside LocalRpcServer's gil_scoped_release scope without holding the GIL, which can crash or trigger undefined behavior.
Fix: Insert py::gil_scoped_acquire gil; before any Python call, scoped over the entire if block (including e.what() in the catch), so every Python API call runs under GIL protection.

P2: Gate deepgemm warmup behind the runtime_config.warm_up master switch — NormalEngine.cc:81
Problem: The DeepGEMM warmup ran independently of the runtime_config.warm_up master switch, so it executed even when the user disabled warmup.
Fix: Add runtime_config.warm_up to the outer condition so it matches the engine warmup; deepgemm_warmup=0 remains the fine-grained disable knob.

P2: Eliminate the hard-coded magic number 8192 — module_base.py:113
Problem: The max_tokens cap of 8192 was a bare magic number.
Fix: Extract it as a class constant DEEPGEMM_WARMUP_MAX_TOKENS = 8192 with a comment explaining its purpose.

P2: Validate the deepgemm_warmup value — model_config.py:529 + deepgemm_warmup.py:335-342
Problem: deepgemm_warmup was a bare tri-state string, so a mistyped value fell back silently.
Fix:
- model_config.py: change the type annotation to Literal[0, 1, "full"]
- deepgemm_warmup.py: validate the value at function entry; log a warning on illegal values and fall back to 1

P3: Narrow the exception catch — module_base.py:124
Problem: A bare except Exception swallowed all exceptions, making recoverable and fatal errors indistinguishable.
Fix: Change to except (ImportError, RuntimeError) to catch only the known-possible failures (missing package, CUDA runtime error); other exceptions propagate. Also pass exc_info=True so the full traceback reaches the log.

Unchanged items
- P3 module-level global sets (_FP8_GEMM_NT_WARMED / _GROUPED_GEMM_WARMED): the current process model runs a single model instance, so cross-instance pollution cannot occur; kept as-is.
- P3 BUILD style: a pure code-style issue with no functional impact; not changed for now.
- Checklist getattr → direct attribute access: fixed; now reads self.config.deepgemm_warmup directly.
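The two Python-side fixes above (value validation and the narrowed exception catch) can be sketched together as follows. This is a hedged illustration, not the PR's code: the function names `validate_deepgemm_warmup` and `safe_warmup` are hypothetical, and only the Literal tri-state, the warn-and-fall-back-to-1 behavior, and the (ImportError, RuntimeError) catch come from the review fixes described above.

```python
import logging
from typing import Callable, Literal, Union

logger = logging.getLogger(__name__)

# Tri-state config value: 0 = disabled, 1 = enabled, "full" = all M values.
DeepgemmWarmupMode = Literal[0, 1, "full"]
_VALID_MODES = (0, 1, "full")


def validate_deepgemm_warmup(value: Union[int, str]) -> DeepgemmWarmupMode:
    """Reject illegal values with a warning instead of failing silently."""
    if value not in _VALID_MODES:
        logger.warning("invalid deepgemm_warmup=%r, falling back to 1", value)
        return 1
    return value


def safe_warmup(warmup_fn: Callable[[], None]) -> bool:
    """Run warmup, catching only known-recoverable failures."""
    try:
        warmup_fn()
        return True
    except (ImportError, RuntimeError):
        # Missing package or CUDA runtime error: log the full traceback
        # and continue startup. Anything else propagates to the caller.
        logger.warning("deepgemm warmup skipped", exc_info=True)
        return False
```

Narrowing the catch keeps genuinely unexpected errors visible at startup, while a missing deep_gemm package merely downgrades to cold-start JIT compilation.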
Force-pushed from 90bc27d to a8c8256
Collaborator
AI Code Review - PR #967 — Status: LGTM (lgtm ready to ci). Summary: P0/0 · P1/0 · P2/3 · P3/0.
Checklist violations: 6 fail / 56 total (General Principles, Python Static-First checklists).
Force-pushed from a8c8256 to a28b3a1
Collaborator
AI Code Review - PR #967 — Status: LGTM (lgtm ready to ci). Summary: P0/0 · P1/0 · P2/1 · P3/0.
Checklist violations: 6 fail / 56 total (General Principles, Python Static-First checklists).