
Feat: onnx export cli #3521

Open
tsuu-abj wants to merge 4 commits into huggingface:main from tsuu-abj:feat/onnx-export-cli

Conversation

@tsuu-abj tsuu-abj commented May 6, 2026

Title

feat(export): add lerobot-export CLI for ONNX/TensorRT edge deployment

Summary / Motivation

Adds a lerobot-export CLI that exports trained LeRobot policies to ONNX (and
optionally TensorRT) for edge deployment on devices such as Jetson Orin Nano/NX.
Users running trained policies on Jetson-class devices need an optimized
inference format; ONNX + TensorRT FP16 provides significant speedup over
PyTorch on SM_87 GPUs without requiring the Python training stack at inference
time.

Related issues

What changed

  • New CLI lerobot-export (entry point lerobot.scripts.lerobot_export)
    supporting --format=onnx (default) and --format=tensorrt.
  • New lerobot.export package with:
    • ExportSpec dataclass + register_export_wrapper plugin registry with
      auto-discovery by naming convention (policies/<type>/export_<type>.py).
    • Reusable adapter primitives: DictBatchAdapter (Pattern A) and
      IterativeDenoisingAdapter ABC (Pattern B) in export/adapters/.
    • Three ONNX exporter backends: dynamo, legacy, and auto (tries dynamo,
      falls back to legacy).
    • validate_onnx() — post-export parity check (max_abs_error, cos_sim,
      torch.allclose). Supports --validation-trials=N for random-input trials.
    • export_to_tensorrt() via trtexec subprocess with engine cache and
      hardware guards (FP8/INT8 on SM_87 raises; FP16 on SM<80 warns).
    • save_normalization_stats() + NormalizedWrapper (opt-in
      --fold-normalization bakes stats as ONNX constants).
  • Per-policy adapters co-located with each policy module:
    • policies/act/export_act.py — ACT via DictBatchAdapter
    • policies/diffusion/export_diffusion.py — UNet-only and full DDIM loop
      (--diffusion-mode=ddim-N)
    • policies/vqbet/export_vqbet.py — VQ-BeT via DictBatchAdapter
  • Unsupported policies (SAC, TDMPC, PI0, SmolVLA) raise NotImplementedError
    with a concrete extension guide instead of silently producing wrong output.
  • New optional dependency group export (onnx, onnxruntime,
    onnxscript) in pyproject.toml; _onnx_available / _onnxruntime_available
    added to import_utils.py.

No breaking changes to existing APIs.

How was this tested (or how to run locally)

Unit tests (31 passed, 1 skipped on macOS/MPS):

uv sync --locked --extra export --extra test
uv run pytest tests/test_export.py tests/policies/test_export_act.py \
    tests/policies/test_export_vqbet.py tests/policies/test_export_diffusion.py -sv

End-to-end validation against real Hub checkpoints (macOS/CPU, opset 18,
--validation-trials=5):

# ACT
uv run lerobot-export \
  --policy.path=lerobot/act_aloha_sim_insertion_human \
  --output_path=./outputs/verify_act --device=cpu --exporter=legacy \
  --validation_trials=5
# => worst max_abs_error=2.09e-06, min cos_sim=1.000000, allclose=True

# Diffusion UNet (batch_size=4, dynamic)
uv run lerobot-export \
  --policy.path=lerobot/diffusion_pusht \
  --output_path=./outputs/verify_diffusion --device=cpu \
  --validation_trials=5
# => worst max_abs_error=9.06e-06, min cos_sim=1.000000, allclose=True
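For readers unfamiliar with the reported numbers, the three parity metrics (`max_abs_error`, `cos_sim`, `allclose`) can be reproduced from the PyTorch and ONNX Runtime output arrays with a few lines of NumPy. This is an illustrative sketch, not the actual `validate_onnx()` implementation; the tolerance defaults are assumptions.

```python
import numpy as np


def parity_metrics(torch_out: np.ndarray, onnx_out: np.ndarray,
                   rtol: float = 1e-4, atol: float = 1e-5) -> dict:
    """Compare two model outputs elementwise (tolerances are assumed defaults)."""
    a = torch_out.ravel().astype(np.float64)
    b = onnx_out.ravel().astype(np.float64)
    return {
        # Worst-case elementwise deviation.
        "max_abs_error": float(np.max(np.abs(a - b))),
        # Global directional agreement of the flattened outputs.
        "cos_sim": float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))),
        # Pass/fail under combined relative + absolute tolerance.
        "allclose": bool(np.allclose(a, b, rtol=rtol, atol=atol)),
    }
```

A `max_abs_error` around 1e-6 with `cos_sim` at 1.000000, as in the runs above, is the expected signature of pure floating-point noise between the two runtimes.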

TensorRT engine builds require CUDA + trtexec; the hardware-guard error paths
are covered by tests/test_export.py::test_export_to_tensorrt_raises_without_cuda.

Checklist (required before merge)

  • Linting/formatting run (pre-commit run -a)
  • All tests pass locally (pytest)
  • Documentation updated
  • CI is green
  • Community Review: I have reviewed another contributor's open PR and linked it here: # (insert PR number/link)

Reviewer notes

  • ACT + dynamo batch>1 is a known limitation: even --exporter=dynamo
    produces a fixed batch_size=1 ONNX because ACT.forward allocates
    torch.zeros([batch_size, latent_dim]) inside the model, causing
    torch.export to specialize on the concrete value. Fixing this requires
    modifying ACT.forward itself and is out of scope for this PR.
  • export/adapters/ contains the two reusable primitives intended to keep
    future per-policy adapters thin. Feedback on the API surface is welcome.
  • TensorRT path is tested only via error-path unit tests (no CUDA on the
    development host); a GPU CI job is a follow-up.
  • Older Hub checkpoints (e.g. lerobot/act_aloha_sim_insertion_human) do not
    ship policy_preprocessor.json, so normalization_stats.json is written as
    {} with a warning — this is expected and logged clearly.

tsuu and others added 4 commits May 6, 2026 19:07
Introduces lerobot/export/ as a new top-level module and the
lerobot-export CLI entry point for exporting trained LeRobot policies
to ONNX and TensorRT for edge/embedded deployment.

Core components:
- export/core.py: ExportConfig dataclass, main export() orchestrator
- export/onnx_export.py: torch.export (dynamo) + torch.onnx fallback paths
- export/tensorrt_export.py: TensorRT engine compilation
- export/normalization.py: norm/denorm ONNX nodes via public processor API
- export/sample_inputs.py: sample-input construction from processor metadata
- export/validation.py: round-trip numerical correctness checks

Adds onnxruntime, onnx, and tensorrt as optional extras in pyproject.toml.

Moves ONNX wrapper logic out of a monolithic export/wrappers.py into
per-policy export modules (policies/act/export_act.py,
policies/diffusion/export_diffusion.py). The core dispatch layer
auto-discovers make_<type>_export_wrapper by name, so new policy support
does not require editing the central export module.

Extracts two reusable adapter primitives into export/adapters/:
- DictBatchAdapter: converts dict-of-tensors policy I/O to positional args
- IterativeDenoisingAdapter: wraps diffusion-style policies for single-pass export

Adds VQBET export support and extends validation.py with --num-validation-trials
for statistical correctness checks against real Hub checkpoints.

Replaces ExportConfig.diffusion_mode with free-form policy_options: dict[str, str]
so future policy types pass export parameters without modifying the CLI dataclass.

Adds make_<type>_export_artifacts auto-discovery for per-policy auxiliary ONNX files.
Fixes onnx.checker.check_model to accept a path string for models larger than 2 GiB.
@github-actions bot added the policies (Items related to robot policies) and tests (Problems with test coverage, failures, or improvements to testing) labels on May 6, 2026
@rylinjames

Hey, just flagging that VLA ONNX/TRT export for pi0, pi0.5, smolvla, and gr00t is already working in reflex-vla (github.com/FastCrest/reflex-vla), if it's a useful reference for this PR.

The VLA export is the hard part. smolvla in particular took a bunch of patches to get right (broken KV-cache wiring, wrong sinusoidal embedding, 5D vs. 4D image dims from the AutoProcessor). We also bake the denoising loop, unrolled, into the monolithic ONNX so the TRT engine is a single graph. Might save some time vs. figuring it out from scratch.



Development

Successfully merging this pull request may close these issues.

Feature: ONNX / TensorRT export for trained policies
