
from_pretrained orchestration + distributed save/load #45409

Merged
3outeille merged 12 commits into moe-sequence-parallel from orchestration-save-load
Apr 14, 2026

Conversation

@3outeille
Member

Summary

  • Full distributed_config integration in from_pretrained() — mesh creation, apply TP + FSDP, attach model.device_mesh
  • gather_full_state_dict() for streaming DTensor→full tensor saving (rank 0 only)
  • convert_strided_to_shard() / restore_strided_from_shard() for DCP compatibility with _StridedShard
  • save_optimizer() / load_optimizer() in distributed/utils.py
  • Rename apply_fsdp2 → apply_fully_shard_data_parallel
  • Trainer integration with distributed_config
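
The summary above implies a specific setup order: build the mesh, apply TP, then FSDP, then expose the mesh on the model. The sketch below only records that ordering so it can be checked; every name in it (`from_pretrained_orchestration`, `create_device_mesh`, etc.) is an illustrative placeholder, not the actual transformers API.

```python
# Illustrative sketch of the orchestration order described in the summary.
# All names here are placeholders, NOT the real transformers API; the point
# is only the sequence in which the distributed setup steps would run.

def from_pretrained_orchestration(distributed_config, steps=None):
    """Record the order in which distributed setup steps would run."""
    steps = [] if steps is None else steps
    if distributed_config is None:
        return steps  # plain single-device load, nothing to orchestrate
    steps.append("create_device_mesh")          # build the DeviceMesh first
    if distributed_config.get("tp_size", 1) > 1:
        steps.append("apply_tensor_parallel")   # TP before FSDP wrapping
    if distributed_config.get("fsdp", False):
        steps.append("apply_fully_shard_data_parallel")
    steps.append("attach_model_device_mesh")    # expose mesh as model.device_mesh
    return steps

print(from_pretrained_orchestration({"tp_size": 2, "fsdp": True}))
```

TP-then-FSDP is the usual composition order in torch-native 2D parallelism, since fully_shard wraps modules whose parameters are already DTensors.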

Part of the distributed training API chain: #44989

Chain: main ← #44989 ← #44083 ← #44974 ← #45028 ← #45408 ← this PR

Review question

Does from_pretrained wire things up in the right order? Is the save/load round-trip correct?

Test plan

  • End-to-end from_pretrained with distributed_config
  • gather_full_state_dict() roundtrip verification
  • save_optimizer() / load_optimizer() roundtrip
  • Run existing FSDP and TP mixin tests
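
The shape of the roundtrip checks in the test plan can be sketched like this, with stdlib pickle standing in for the real save_optimizer()/load_optimizer() helpers (which operate on distributed state via DCP, not on plain files):

```python
# Roundtrip-check sketch: save a state dict, reload it, assert exact equality.
# pickle is a stand-in here; the PR's save_optimizer()/load_optimizer() use
# torch distributed checkpointing, not plain pickling.
import os
import pickle
import tempfile

def save_state(state, path):
    with open(path, "wb") as f:
        pickle.dump(state, f)

def load_state(path):
    with open(path, "rb") as f:
        return pickle.load(f)

state = {"step": 10, "exp_avg": [0.1, 0.2, 0.3], "lr": 1e-3}
path = os.path.join(tempfile.mkdtemp(), "optim.pkl")
save_state(state, path)
restored = load_state(path)
assert restored == state  # roundtrip preserves every entry exactly
```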

- Add gather_full_state_dict() for DTensor→full tensor saving
- Add convert_strided_to_shard() / restore_strided_from_shard() for DCP
- Add _redistribute_dtensor() helper
- Full distributed_config integration in from_pretrained/save_pretrained
- Rename apply_fsdp2 → apply_fully_shard_data_parallel
- save_optimizer() / load_optimizer() in distributed/utils
- Trainer integration with distributed_config
- Updated FSDP and TP tests for new orchestration API
- DTensor shard-on-read test updates
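
On the convert_strided_to_shard() motivation: _StridedShard places elements on ranks in an interleaved (round-robin) pattern, while DCP-style checkpoints expect each rank to hold one contiguous block. The toy functions below illustrate that layout conversion on plain Python lists; they are conceptual only, not the PR's implementation:

```python
# Conceptual sketch: strided (round-robin) sharding vs. contiguous sharding.
# These helpers are illustrative, not the PR's convert_strided_to_shard().

def strided_shards(rows, world_size):
    """Round-robin assignment: rank r gets rows r, r+W, r+2W, ..."""
    return [rows[r::world_size] for r in range(world_size)]

def to_contiguous_shards(shards, world_size):
    """Reassemble the full order, then cut into contiguous blocks."""
    n = sum(len(s) for s in shards)
    full = [None] * n
    for r, shard in enumerate(shards):
        for i, row in enumerate(shard):
            full[r + i * world_size] = row  # undo the interleaving
    block = n // world_size
    return [full[r * block:(r + 1) * block] for r in range(world_size)]

rows = list(range(8))                      # pretend each int is a weight row
strided = strided_shards(rows, 2)          # [[0, 2, 4, 6], [1, 3, 5, 7]]
contiguous = to_contiguous_shards(strided, 2)
print(contiguous)                          # [[0, 1, 2, 3], [4, 5, 6, 7]]
```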
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@3outeille force-pushed the moe-sequence-parallel branch from e04c7d9 to 24ca327 (April 14, 2026 09:54)
@3outeille force-pushed the orchestration-save-load branch from 815b5b2 to 7361deb (April 14, 2026 09:55)
@3outeille force-pushed the moe-sequence-parallel branch from 24ca327 to 7f297e0 (April 14, 2026 13:44)
# Conflicts:
#	src/transformers/distributed/utils.py
@3outeille force-pushed the orchestration-save-load branch from 7361deb to 1ecc329 (April 14, 2026 13:45)
@3outeille merged commit bbf3ab6 into moe-sequence-parallel on Apr 14, 2026
19 of 28 checks passed
@3outeille deleted the orchestration-save-load branch on April 14, 2026 16:10
3outeille added a commit that referenced this pull request Apr 14, 2026
* MoE expert parallelism + sequence parallelism

- Add PackedColwiseParallel for fused gate_up_proj weights
- Add MoEExpertsParallel with per-expert DTensor sharding
- Add PrepareModuleInputOutput for SP allgather/split hooks
- Add _AllReduceBackward for MoE routing weight gradients
- Extend TPStyle with moe_experts, packed_colwise, activation, module kinds
- _StridedShard handling in core_model_loading for interleaved weights
- MoE model configs: mixtral, deepseek_v3, qwen3 with SP plans
- DTensor rotary_pos_emb guard for mixtral
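
The packed colwise idea in the commit above can be illustrated on plain lists: a fused gate_up_proj weight concatenates gate and up along the output dim, so a naive contiguous split would hand low ranks only gate rows. Splitting each packed half independently keeps matching gate/up slices on every rank. Names below are illustrative, not the PR's PackedColwiseParallel:

```python
# Illustrative sketch of a packed column-wise split for fused gate_up_proj
# weights (gate rows then up rows along the output dim). Not the PR's
# PackedColwiseParallel implementation.

def naive_split(rows, world_size):
    block = len(rows) // world_size
    return [rows[r * block:(r + 1) * block] for r in range(world_size)]

def packed_colwise_split(rows, world_size, num_packed=2):
    """Split each packed half independently, then re-concatenate per rank."""
    half = len(rows) // num_packed
    halves = [rows[h * half:(h + 1) * half] for h in range(num_packed)]
    per_half = [naive_split(h, world_size) for h in halves]
    return [sum((ph[r] for ph in per_half), []) for r in range(world_size)]

# 4 gate rows ("g0".."g3") fused with 4 up rows ("u0".."u3"), 2 ranks:
rows = ["g0", "g1", "g2", "g3", "u0", "u1", "u2", "u3"]
print(naive_split(rows, 2))           # rank 0 gets only gate rows
print(packed_colwise_split(rows, 2))  # each rank gets gate AND up rows
```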

* Fix ruff linting and formatting

* Fix ruff formatting in core_model_loading.py

* Restore _IdentityOp accidentally removed in 25a1f48

The _IdentityOp class (added by PR #44983) was accidentally deleted
during the MoE expert parallelism work. It is needed by
finegrained_fp8.py and metal_quantization.py as a pass-through
reverse_op for dequantize operations.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* Backport new TP/FSDP API + fix DTensor imports in Copied-from models

* from_pretrained orchestration + distributed save/load (#45409)

* from_pretrained orchestration + save/load

- Add gather_full_state_dict() for DTensor→full tensor saving
- Add convert_strided_to_shard() / restore_strided_from_shard() for DCP
- Add _redistribute_dtensor() helper
- Full distributed_config integration in from_pretrained/save_pretrained
- Rename apply_fsdp2 → apply_fully_shard_data_parallel
- save_optimizer() / load_optimizer() in distributed/utils
- Trainer integration with distributed_config
- Updated FSDP and TP tests for new orchestration API
- DTensor shard-on-read test updates

* revert distributed utils

* eaaea

* all tests for core modeling are passing

* populate import from init for tp

* ruff

* ruff

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@github-actions
Contributor

View the CircleCI Test Summary for this PR:

https://huggingface.co/spaces/transformers-community/circle-ci-viz?pr=45409&sha=39bea2

