Cuda graph replays on capture error#1253
Conversation
|
Note Reviews pausedIt looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the Use the following commands to manage reviews:
Use the checkboxes below for quick actions:
📝 WalkthroughWalkthroughAdd a lazy manual_cuda_graph_t wrapper and migrate manual CUDA-graph capture/instantiate/launch call sites (PDHG, feasibility jump, ping-pong graph, weighted-average, adaptive step-size) to its run(stream, work) API; re-enable previously skipped incumbent callback tests. ChangesCUDA Graph Abstraction and Solver Integration
Estimated code review effort🎯 4 (Complex) | ⏱️ ~60 minutes Possibly related PRs
Suggested reviewers
Suggested labels
🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Comment |
There was a problem hiding this comment.
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
cpp/src/utilities/manual_cuda_graph.cuh (2)
136-145:⚠️ Potential issue | 🔴 Critical | ⚡ Quick winDouble-free of CUDA graph handle when
dummy == parent.With
cudaStreamBeginCaptureToGraph, the captured graph returned bycudaStreamEndCaptureis the same handle as the parent graph passed in. If user code throws mid-capture and capture wasn't invalidated,dummy == parent, so line 142 destroys it, then line 144 destroys the same handle again.🐛 Proposed fix: skip destroying dummy since it aliases parent
~capture_guard_t() noexcept { if (capture_active) { cudaGraph_t dummy = nullptr; // best-effort; we're already unwinding cudaStreamEndCapture(stream, &dummy); - if (dummy != nullptr) { cudaGraphDestroy(dummy); } + // dummy == parent for manual capture; destroying parent below handles it } - if (parent != nullptr) { cudaGraphDestroy(parent); } + if (parent != nullptr) { RAFT_CUDA_TRY_NO_THROW(cudaGraphDestroy(parent)); } }🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@cpp/src/utilities/manual_cuda_graph.cuh` around lines 136 - 145, The destructor ~capture_guard_t() can double-destroy the same cudaGraph_t when cudaStreamEndCapture returns the same handle as parent; modify the cleanup so that after calling cudaStreamEndCapture(stream, &dummy) you only call cudaGraphDestroy on dummy if dummy != parent (or alternatively only destroy parent and skip destroying dummy when they alias), ensuring you still destroy parent if it is non-null and avoid calling cudaGraphDestroy twice on the same handle (refer to symbols: ~capture_guard_t, capture_active, cudaStreamEndCapture, dummy, parent, cudaGraphDestroy).
84-85:⚠️ Potential issue | 🔴 Critical
cudaStreamBeginCaptureToGraphrequires CUDA 12.3+ but the project supports CUDA 12.0+.This API was introduced in CUDA 12.3. The unconditional use at lines 84–85 will cause build failures on CUDA 12.0, 12.1, and 12.2. Either add a version guard with a fallback implementation or update the project's minimum CUDA requirement to 12.3.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@cpp/src/utilities/manual_cuda_graph.cuh` around lines 84 - 85, The call to cudaStreamBeginCaptureToGraph inside manual_cuda_graph.cuh (wrapped by RAFT_CUDA_TRY) requires CUDA 12.3+, but the project supports CUDA 12.0–12.2; guard the call with a CUDA version check (e.g., `#if` defined(CUDA_VERSION) && CUDA_VERSION >= 12030) and provide a fallback for older toolkits (call the older cudaStreamBeginCapture API or cudaStreamBeginCapture(stream.value(), cudaStreamCaptureModeThreadLocal) within the RAFT_CUDA_TRY) so builds on CUDA 12.0–12.2 use the compatible capture API; ensure both branches use the same error handling macro (RAFT_CUDA_TRY) and keep the unique symbol names (cudaStreamBeginCaptureToGraph, cudaStreamBeginCapture, stream.value(), RAFT_CUDA_TRY) to locate and implement the change.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Outside diff comments:
In `@cpp/src/utilities/manual_cuda_graph.cuh`:
- Around line 136-145: The destructor ~capture_guard_t() can double-destroy the
same cudaGraph_t when cudaStreamEndCapture returns the same handle as parent;
modify the cleanup so that after calling cudaStreamEndCapture(stream, &dummy)
you only call cudaGraphDestroy on dummy if dummy != parent (or alternatively
only destroy parent and skip destroying dummy when they alias), ensuring you
still destroy parent if it is non-null and avoid calling cudaGraphDestroy twice
on the same handle (refer to symbols: ~capture_guard_t, capture_active,
cudaStreamEndCapture, dummy, parent, cudaGraphDestroy).
- Around line 84-85: The call to cudaStreamBeginCaptureToGraph inside
manual_cuda_graph.cuh (wrapped by RAFT_CUDA_TRY) requires CUDA 12.3+, but the
project supports CUDA 12.0–12.2; guard the call with a CUDA version check (e.g.,
`#if` defined(CUDA_VERSION) && CUDA_VERSION >= 12030) and provide a fallback for
older toolkits (call the older cudaStreamBeginCapture API or
cudaStreamBeginCapture(stream.value(), cudaStreamCaptureModeThreadLocal) within
the RAFT_CUDA_TRY) so builds on CUDA 12.0–12.2 use the compatible capture API;
ensure both branches use the same error handling macro (RAFT_CUDA_TRY) and keep
the unique symbol names (cudaStreamBeginCaptureToGraph, cudaStreamBeginCapture,
stream.value(), RAFT_CUDA_TRY) to locate and implement the change.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: f4efb04c-5af5-4c02-baa4-c8ea9d79fa41
📒 Files selected for processing (1)
cpp/src/utilities/manual_cuda_graph.cuh
There was a problem hiding this comment.
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
cpp/src/utilities/manual_cuda_graph.cuh (1)
76-92:⚠️ Potential issue | 🟠 Major | ⚡ Quick winDestroy the captured graph on instantiation failure.
After a successful
cudaStreamEndCapture,capturedowns a graph. IfcudaGraphInstantiatefails, the throwing macro skips thecudaGraphDestroybelow, so each failed first-run/retry leaks a graph handle.♻️ Proposed fix
cudaGraph_t captured = nullptr; cudaError_t end_err = cudaStreamEndCapture(stream.value(), &captured); guard.capture_active = false; @@ - RAFT_CUDA_TRY(cudaGraphInstantiate(&instance_, captured)); - RAFT_CUDA_TRY(cudaGraphDestroy(captured)); + try { + RAFT_CUDA_TRY(cudaGraphInstantiate(&instance_, captured)); + RAFT_CUDA_TRY(cudaGraphDestroy(captured)); + captured = nullptr; + } catch (...) { + if (captured != nullptr) { RAFT_CUDA_TRY_NO_THROW(cudaGraphDestroy(captured)); } + throw; + }As per coding guidelines, "Flag missing RAII in exception paths since cuOpt uses exceptions."
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@cpp/src/utilities/manual_cuda_graph.cuh` around lines 76 - 92, After cudaStreamEndCapture succeeds and before calling cudaGraphInstantiate(&instance_, captured), ensure that on instantiate failure the captured graph is destroyed to avoid leaking the graph handle; replace or wrap the RAFT_CUDA_TRY(cudaGraphInstantiate(&instance_, captured)) call so that if cudaGraphInstantiate returns an error you call cudaGraphDestroy(captured) and then propagate the error (or rethrow) rather than letting the macro skip the destroy; reference symbols: captured, cudaGraphInstantiate, cudaGraphDestroy, RAFT_CUDA_TRY, and instance_.
🧹 Nitpick comments (1)
cpp/src/utilities/manual_cuda_graph.cuh (1)
23-34: ⚡ Quick winDocument that invalidation recovery replays
work()on the host.The fallback path preserves device-side results, but any host-side mutation inside
workhas already happened once during the failed capture attempt and will happen again here. Please call out thatworkmust be host-idempotent, or keep host bookkeeping outside the callable.📝 Suggested clarification
// Wrapper around a CUDA graph captured from a callable. CUB / Thrust / RAFT / // cuSPARSE calls inside the captured region are preserved. +// `work` must not perform non-idempotent host-side mutations: if capture is +// invalidated, it is executed once during the failed capture attempt and once +// again in the eager fallback path. // // Invalidation recovery:Also applies to: 80-87
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@cpp/src/utilities/manual_cuda_graph.cuh` around lines 23 - 34, Add a sentence to the invalidation recovery doc block explaining that when cudaStreamEndCapture returns cudaErrorStreamCaptureInvalidated the wrapper drains the sticky error and re-executes the provided callable (work()) on the host, so any host-side mutations inside work() will run twice; update text near the description of cudaStreamEndCapture, work(), and run to state that work must be host-idempotent or that host bookkeeping should be moved out of the callable to avoid double application during recovery.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Outside diff comments:
In `@cpp/src/utilities/manual_cuda_graph.cuh`:
- Around line 76-92: After cudaStreamEndCapture succeeds and before calling
cudaGraphInstantiate(&instance_, captured), ensure that on instantiate failure
the captured graph is destroyed to avoid leaking the graph handle; replace or
wrap the RAFT_CUDA_TRY(cudaGraphInstantiate(&instance_, captured)) call so that
if cudaGraphInstantiate returns an error you call cudaGraphDestroy(captured) and
then propagate the error (or rethrow) rather than letting the macro skip the
destroy; reference symbols: captured, cudaGraphInstantiate, cudaGraphDestroy,
RAFT_CUDA_TRY, and instance_.
---
Nitpick comments:
In `@cpp/src/utilities/manual_cuda_graph.cuh`:
- Around line 23-34: Add a sentence to the invalidation recovery doc block
explaining that when cudaStreamEndCapture returns
cudaErrorStreamCaptureInvalidated the wrapper drains the sticky error and
re-executes the provided callable (work()) on the host, so any host-side
mutations inside work() will run twice; update text near the description of
cudaStreamEndCapture, work(), and run to state that work must be host-idempotent
or that host bookkeeping should be moved out of the callable to avoid double
application during recovery.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: 376b74a0-8809-454f-88b6-7a0b5a4546bc
📒 Files selected for processing (1)
cpp/src/utilities/manual_cuda_graph.cuh
| if (use_graph) { | ||
| step_graph_.run(climber_stream, step_body); | ||
| } else { | ||
| step_body(); |
| // caution Binary part is because in pdlp we swap pointers instead of copying vectors to accept a | ||
| // valid pdhg step So every odd pdlp step it's one graph, every even step it's another graph | ||
| // Two-slot CUDA-graph cache for PDLP. PDLP swaps pointers (rather than | ||
| // copying vectors) at the end of every pdhg step, so the captured graph |
There was a problem hiding this comment.
I would just say adaptive pdhg step. There is no pointer swap when using fixed pdhg step (which is what we do for default/Stable3)
| // Currently graph capture is not supported for cuSparse SpMM | ||
| // TODO enable once cuSparse SpMM supports graph capture | ||
| graph_all{stream_view_, is_legacy_batch_mode || batch_mode_}, | ||
| graph_all_non_major{stream_view_, is_legacy_batch_mode || batch_mode_}, |
There was a problem hiding this comment.
Why not keep a single graph_all and use the is_major to swap between major and non-major like before?
There was a problem hiding this comment.
Clarifying things on slack.
There was a problem hiding this comment.
Confirm this is wrong, reverting back the previous graph_all mecanism
| private: | ||
| // RAII helper: cleans up a partial capture if the user-supplied callable | ||
| // throws between start- and end-capture. | ||
| struct capture_guard_t { |
There was a problem hiding this comment.
Very cool mechanism, I love it
There was a problem hiding this comment.
Credits to Claude :)
| capture_guard_t guard{stream.value()}; | ||
|
|
||
| RAFT_CUDA_TRY(cudaStreamBeginCapture(stream.value(), cudaStreamCaptureModeThreadLocal)); | ||
| guard.capture_active = true; |
There was a problem hiding this comment.
I guess there is a tiny risk here if there is an error exactly between begin capture and settings capture_active = true
There was a problem hiding this comment.
I think since the error is contained in this thread, it should be okay. If there is an exception in another thread and is uncaught, it will abort the process anyway.
hlinsen
left a comment
There was a problem hiding this comment.
Very clean solution, thanks @akifcorduk!
| // swaps the primal/dual ping-pong buffers between outer pdlp iterations — so the captured | ||
| // graph's baked-in pointers depend on `total_pdlp_iterations` parity, not on `should_major`. | ||
| // Use a dedicated ping-pong cache per branch and key each on `total_pdlp_iterations` so each | ||
| // (branch, parity) pair maps to its own cached executable. |
There was a problem hiding this comment.
That's untrue, there is no update_solution when using reflected. update_solution is only called in the take_adaptive_step while reflected is used in the take_constant_step
|
/ok to test 035150d |
|
/merge |
This is a permanent fix to cuda graph capture issue. We add a small RAII wrapper around
cudaStreamBeginCapture/cudaStreamEndCapturethat detectscudaErrorStreamCaptureInvalidatedat EndCapture time, drops the (never-issued) partial graph, re-runs the callable eagerly so the current iteration still produces correct results, and stays uninitialized so the next call retries capture. One extra eager pass instead of a crash.Closing the other PR:#1250
Fixes #1185