Skip to content

feat(diffusion): add gradient checkpointing for memory optimization#3503

Open
jashshah999 wants to merge 1 commit intohuggingface:mainfrom
jashshah999:feat/diffusion-gradient-checkpointing
Open

feat(diffusion): add gradient checkpointing for memory optimization#3503
jashshah999 wants to merge 1 commit intohuggingface:mainfrom
jashshah999:feat/diffusion-gradient-checkpointing

Conversation

@jashshah999
Copy link
Copy Markdown
Contributor

Summary

Adds gradient checkpointing to DiffusionPolicy's UNet (encoder, mid, and decoder residual blocks). This trades compute for memory, allowing training with larger batch sizes or higher-resolution inputs on memory-constrained GPUs.

Currently only Pi0 and XVLA have gradient checkpointing. This extends it to DiffusionPolicy as called for in the 0.6.0 roadmap (item 3.3).

Usage

lerobot-train \
  --policy.type=diffusion \
  --policy.gradient_checkpointing=true \
  --dataset.repo_id=lerobot/pusht

What changed

  • Added gradient_checkpointing: bool = False to DiffusionConfig
  • Wrapped UNet encoder/mid/decoder residual blocks with torch.utils.checkpoint.checkpoint when enabled and training
  • Uses use_reentrant=False for compatibility with torch.compile

Test plan

  • Training without gradient_checkpointing works as before
  • Training with gradient_checkpointing=true runs without error
  • Peak GPU memory is reduced with gradient_checkpointing enabled
  • Training loss converges similarly with and without checkpointing

0.6.0 roadmap item 3.3.

Add gradient_checkpointing config option to DiffusionPolicy. When
enabled, wraps the UNet encoder, mid, and decoder residual blocks
with torch.utils.checkpoint.checkpoint to trade compute for memory.

Allows training with larger batch sizes or higher-resolution inputs
on memory-constrained GPUs. Disabled by default.

Usage: --policy.gradient_checkpointing=true

Part of the 0.6.0 roadmap item 3.3 (gradient checkpointing for all
policies).
@github-actions github-actions Bot added the policies Items related to robot policies label May 4, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

policies Items related to robot policies

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant