Commit f7fd76a

sayakpaul and DN6 authored
[attention backends] fix ring CP for flash and flash 3 (#13182)
* tests: add cp backend and attention backend tests.
* up
* up
* up
* fix ring for flash and flash_3
* generate.
* Apply suggestions from code review

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>

* up
* up

---------

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
1 parent dad80d7 commit f7fd76a

6 files changed: 139 additions, 10 deletions


src/diffusers/models/attention_dispatch.py

Lines changed: 11 additions & 7 deletions
@@ -1914,9 +1914,12 @@ def forward(
         out = out.to(torch.float32)
         lse = lse.to(torch.float32)

-        # Refer to:
-        # https://github.com/huggingface/diffusers/pull/12693#issuecomment-3627519544
-        if is_torch_version("<", "2.9.0"):
+        # lse must be 4-D to broadcast with out (B, S, H, D).
+        # Some backends (e.g. cuDNN on torch>=2.9) already return a
+        # trailing-1 dim; others (e.g. flash-hub / native-flash) always
+        # return 3-D lse, so we add the dim here when needed.
+        # See: https://github.com/huggingface/diffusers/pull/12693#issuecomment-3627519544
+        if lse.ndim == 3:
             lse = lse.unsqueeze(-1)
         if prev_out is not None:
             out = prev_out - torch.nn.functional.sigmoid(lse - prev_lse) * (prev_out - out)
@@ -2203,10 +2206,11 @@ def _templated_unified_attention(
         scatter_idx,
     )
     if return_lse:
-        # lse is of shape (B, S, H_LOCAL, 1)
-        # Refer to:
-        # https://github.com/huggingface/diffusers/pull/12693#issuecomment-3627519544
-        if is_torch_version("<", "2.9.0"):
+        # lse from TemplatedRingAttention is 3-D (B, S, H_LOCAL) after its
+        # final squeeze(-1). SeqAllToAllDim requires a 4-D input, so we add
+        # the trailing dim here and remove it after the collective.
+        # See: https://github.com/huggingface/diffusers/pull/12693#issuecomment-3627519544
+        if lse.ndim == 3:
             lse = lse.unsqueeze(-1)  # (B, S, H_LOCAL, 1)
         lse = SeqAllToAllDim.apply(ulysses_group, lse, gather_idx, scatter_idx)
         lse = lse.squeeze(-1)
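The ring path above folds each step's partial output into a running accumulator with a sigmoid-weighted log-sum-exp merge, so lse has to broadcast against out of shape (B, S, H, D). Gating the unsqueeze on the torch version missed the flash and flash-3 backends, which return a 3-D lse regardless of the torch version; checking lse.ndim fixes that. A minimal standalone sketch of the merge under assumed shapes (illustrative helper, not the library code; the running-lse update is omitted):

import torch

def merge_ring_step(prev_out, prev_lse, out, lse):
    # Some backends return lse as (B, S, H); it must become (B, S, H, 1) to
    # broadcast against out of shape (B, S, H, D) -- the ndim check handles this.
    if lse.ndim == 3:
        lse = lse.unsqueeze(-1)
    if prev_out is None:
        return out, lse
    out = prev_out - torch.nn.functional.sigmoid(lse - prev_lse) * (prev_out - out)
    return out, lse

B, S, H, D = 1, 16, 2, 8
prev_out, out = torch.randn(B, S, H, D), torch.randn(B, S, H, D)
prev_lse = torch.randn(B, S, H, 1)  # already 4-D (e.g. cuDNN on torch>=2.9)
lse = torch.randn(B, S, H)          # 3-D (e.g. flash / flash-3)
merged, _ = merge_ring_step(prev_out, prev_lse, out, lse)
assert merged.shape == (B, S, H, D)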

tests/models/testing_utils/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -13,7 +13,7 @@
 from .ip_adapter import IPAdapterTesterMixin
 from .lora import LoraHotSwappingForModelTesterMixin, LoraTesterMixin
 from .memory import CPUOffloadTesterMixin, GroupOffloadTesterMixin, LayerwiseCastingTesterMixin, MemoryTesterMixin
-from .parallelism import ContextParallelTesterMixin
+from .parallelism import ContextParallelAttentionBackendsTesterMixin, ContextParallelTesterMixin
 from .quantization import (
     BitsAndBytesCompileTesterMixin,
     BitsAndBytesConfigMixin,
@@ -45,6 +45,7 @@
     "BitsAndBytesTesterMixin",
     "CacheTesterMixin",
     "ContextParallelTesterMixin",
+    "ContextParallelAttentionBackendsTesterMixin",
     "CPUOffloadTesterMixin",
     "FasterCacheConfigMixin",
     "FasterCacheTesterMixin",

tests/models/testing_utils/parallelism.py

Lines changed: 87 additions & 1 deletion
@@ -25,10 +25,13 @@
 from diffusers.models.attention_dispatch import AttentionBackendName, _AttentionBackendRegistry

 from ...testing_utils import (
+    is_attention,
     is_context_parallel,
+    is_kernels_available,
     require_torch_multi_accelerator,
     torch_device,
 )
+from .utils import _maybe_cast_to_bf16


 # Device configuration mapping
@@ -47,7 +50,9 @@ def _find_free_port():
     return port


-def _context_parallel_worker(rank, world_size, master_port, model_class, init_dict, cp_dict, inputs_dict, return_dict):
+def _context_parallel_worker(
+    rank, world_size, master_port, model_class, init_dict, cp_dict, inputs_dict, return_dict, attention_backend=None
+):
     """Worker function for context parallel testing."""
     try:
         # Set up distributed environment
@@ -73,9 +78,16 @@ def _context_parallel_worker(rank, world_size, master_port, model_class, init_di
         model.to(device)
         model.eval()

+        # Cast as needed.
+        model, inputs_dict = _maybe_cast_to_bf16(attention_backend, model, inputs_dict)
+
         # Move inputs to device
         inputs_on_device = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs_dict.items()}

+        # Enable attention backend
+        if attention_backend:
+            model.set_attention_backend(attention_backend)
+
         # Enable context parallelism
         cp_config = ContextParallelConfig(**cp_dict)
         model.enable_parallelism(config=cp_config)
@@ -356,3 +368,77 @@ def test_context_parallel_custom_mesh(self, cp_type, mesh_shape, mesh_dim_names)
         assert return_dict.get("status") == "success", (
             f"Custom mesh context parallel inference failed: {return_dict.get('error', 'Unknown error')}"
         )
+
+
+@is_attention
+@is_context_parallel
+@require_torch_multi_accelerator
+class ContextParallelAttentionBackendsTesterMixin:
+    @pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"])
+    @pytest.mark.parametrize(
+        "attention_backend",
+        [
+            "native",
+            pytest.param(
+                "flash_hub",
+                marks=pytest.mark.skipif(not is_kernels_available(), reason="`kernels` is not available."),
+            ),
+            pytest.param(
+                "_flash_3_hub",
+                marks=pytest.mark.skipif(not is_kernels_available(), reason="`kernels` is not available."),
+            ),
+        ],
+    )
+    @pytest.mark.parametrize("ulysses_anything", [True, False])
+    @torch.no_grad()
+    def test_context_parallel_attn_backend_inference(self, cp_type, attention_backend, ulysses_anything):
+        if not torch.distributed.is_available():
+            pytest.skip("torch.distributed is not available.")
+
+        if getattr(self.model_class, "_cp_plan", None) is None:
+            pytest.skip("Model does not have a _cp_plan defined for context parallel inference.")
+
+        if cp_type == "ring_degree":
+            if attention_backend == AttentionBackendName.NATIVE:
+                pytest.skip("Skipping test because ring isn't supported with native attention backend.")
+
+        if ulysses_anything and "ulysses" not in cp_type:
+            pytest.skip("Skipping test as ulysses anything needs the ulysses degree set.")
+
+        world_size = 2
+        init_dict = self.get_init_dict()
+        inputs_dict = self.get_dummy_inputs()
+
+        # Move all tensors to CPU for multiprocessing
+        inputs_dict = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in inputs_dict.items()}
+        cp_dict = {cp_type: world_size}
+        if ulysses_anything:
+            cp_dict.update({"ulysses_anything": ulysses_anything})
+
+        # Find a free port for distributed communication
+        master_port = _find_free_port()
+
+        # Use multiprocessing manager for cross-process communication
+        manager = mp.Manager()
+        return_dict = manager.dict()
+
+        # Spawn worker processes
+        mp.spawn(
+            _context_parallel_worker,
+            args=(
+                world_size,
+                master_port,
+                self.model_class,
+                init_dict,
+                cp_dict,
+                inputs_dict,
+                return_dict,
+                attention_backend,
+            ),
+            nprocs=world_size,
+            join=True,
+        )
+
+        assert return_dict.get("status") == "success", (
+            f"Context parallel inference failed: {return_dict.get('error', 'Unknown error')}"
+        )
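The worker above selects the attention backend before enabling context parallelism, casting to bf16 first for backends that need it. Outside the test harness, the equivalent user-facing flow looks roughly like this (a sketch assuming a 2-process torchrun launch; the checkpoint is a placeholder, while set_attention_backend, ContextParallelConfig, and the "_flash_3_hub"/ring combination come from this PR and the existing diffusers API):

import torch
import torch.distributed as dist
from diffusers import ContextParallelConfig, FluxTransformer2DModel

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank())

model = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
).to("cuda")
model.set_attention_backend("_flash_3_hub")  # flash kernels expect bf16/fp16 tensors
model.enable_parallelism(config=ContextParallelConfig(ring_degree=2))

# ... run the forward pass as usual; dims listed in the model's _cp_plan are
# sharded across ranks and gathered back automatically.

The worker applies the backend before enabling parallelism; the sketch mirrors that order.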
tests/models/testing_utils/utils.py

Lines changed: 22 additions & 0 deletions
@@ -0,0 +1,22 @@
+import torch
+
+from diffusers.models.attention_dispatch import AttentionBackendName
+
+
+_BF16_REQUIRED_BACKENDS = {
+    AttentionBackendName._NATIVE_CUDNN,
+    AttentionBackendName.FLASH_HUB,
+    AttentionBackendName._FLASH_3_HUB,
+}
+
+
+def _maybe_cast_to_bf16(backend, model, inputs_dict):
+    """Cast model and floating-point inputs to bfloat16 when the backend requires it."""
+    if not backend or backend not in _BF16_REQUIRED_BACKENDS:
+        return model, inputs_dict
+    model = model.to(dtype=torch.bfloat16)
+    inputs_dict = {
+        k: v.to(dtype=torch.bfloat16) if isinstance(v, torch.Tensor) and v.is_floating_point() else v
+        for k, v in inputs_dict.items()
+    }
+    return model, inputs_dict
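This helper exists because the cuDNN, flash-hub, and flash-3-hub backends in the set above expect half-precision tensors, while the dummy test models and inputs are fp32 by default; the worker calls it right after model.eval(). A hypothetical direct call:

from diffusers.models.attention_dispatch import AttentionBackendName

model, inputs_dict = _maybe_cast_to_bf16(AttentionBackendName.FLASH_HUB, model, inputs_dict)
# For backends outside _BF16_REQUIRED_BACKENDS (e.g. "native"), both are returned unchanged.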

tests/models/transformers/test_models_transformer_flux.py

Lines changed: 7 additions & 0 deletions
@@ -29,6 +29,7 @@
     BaseModelTesterConfig,
     BitsAndBytesCompileTesterMixin,
     BitsAndBytesTesterMixin,
+    ContextParallelAttentionBackendsTesterMixin,
     ContextParallelTesterMixin,
     FasterCacheTesterMixin,
     FirstBlockCacheTesterMixin,
@@ -245,6 +246,12 @@ class TestFluxTransformerContextParallel(FluxTransformerTesterConfig, ContextPar
     """Context Parallel inference tests for Flux Transformer"""


+class TestFluxTransformerContextParallelAttnBackends(
+    FluxTransformerTesterConfig, ContextParallelAttentionBackendsTesterMixin
+):
+    """Context Parallel inference x attention backends tests for Flux Transformer"""
+
+
 class TestFluxTransformerIPAdapter(FluxTransformerTesterConfig, IPAdapterTesterMixin):
     """IP Adapter tests for Flux Transformer."""

utils/generate_model_tests.py

Lines changed: 10 additions & 1 deletion
@@ -72,6 +72,7 @@
     # Other testers
     ("SingleFileTesterMixin", "single_file"),
     ("IPAdapterTesterMixin", "ip_adapter"),
+    ("ContextParallelAttentionBackendsTesterMixin", "cp_attn"),
 ]

@@ -229,7 +230,14 @@ def determine_testers(model_info: dict, include_optional: list[str], imports: se

     for tester, flag in OPTIONAL_TESTERS:
         if flag in include_optional:
-            if tester not in testers:
+            if tester == "ContextParallelAttentionBackendsTesterMixin":
+                if (
+                    "cp_attn" in include_optional
+                    and "_cp_plan" in model_info["attributes"]
+                    and model_info["attributes"]["_cp_plan"] is not None
+                ):
+                    testers.append(tester)
+            elif tester not in testers:
                 testers.append(tester)

     return testers
@@ -530,6 +538,7 @@ def main():
             "faster_cache",
             "single_file",
             "ip_adapter",
+            "cp_attn",
             "all",
         ],
         help="Optional testers to include",
