
Commit 50cb2db

songh11 and sayakpaul authored
feat: support ring attention with arbitrary KV sequence lengths (#13545)
* feat: support ring attention with arbitrary KV sequence lengths
* fix: align ring_anything with ulysses_anything (size gather + unshard)
* docs: document ring_anything mode
* fix: merge hook branches, add ring_anything comment + guard
* docs: address ring_anything review comments
* docs: update ring_anything guidance
* docs: refine ring_anything guidance per review
* fix: address ring_anything style check

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
1 parent 0fff459 commit 50cb2db

4 files changed

Lines changed: 198 additions & 16 deletions


docs/source/en/training/distributed_inference.md

Lines changed: 41 additions & 0 deletions
@@ -371,6 +371,47 @@ We ran a benchmark for FLUX.1-dev with Ulysses, Ring, Unified Attention and Ulys

From the above table, it is clear that Ulysses Anything Attention offers better compatibility with arbitrary sequence lengths while maintaining the same performance as the standard Ulysses Attention.

### Ring Anything Attention

The default [Ring Attention](https://huggingface.co/papers/2310.01889) requires the sequence length of the hidden states to be evenly divisible by the ring degree. [Ring Anything Attention](https://github.com/huggingface/diffusers/pull/13545#issuecomment-4302195582) is a variant of Ring Attention that supports arbitrary (non-evenly divisible) sequence lengths. It pads each rank's local KV to the global maximum sequence length, all-gathers the padded KV buffers, and slices each buffer back to its rank's true length before running attention.
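To make the scheme concrete, here is a minimal single-process sketch of the pad/all-gather/slice bookkeeping, with a plain Python list standing in for the ranks of the ring (shapes are illustrative assumptions, not taken from the PR):

```py
import torch

# Each entry in local_kv stands for one rank's local KV shard; lengths differ.
local_kv = [torch.randn(1, s, 8, 64) for s in (341, 341, 342)]  # (B, S_kv, H, D) per rank

seq_lens = [t.shape[1] for t in local_kv]  # in the real kernel, gathered across ranks
s_max = max(seq_lens)

# Pad every shard on the sequence dim (dim=1) so the all-gather buffers match in size.
padded = [
    torch.cat([t, t.new_zeros(t.shape[0], s_max - t.shape[1], *t.shape[2:])], dim=1)
    for t in local_kv
]

# After the (simulated) all-gather, slice each buffer back to its true length.
recovered = [p[:, :s] for p, s in zip(padded, seq_lens)]
assert all(torch.equal(r, t) for r, t in zip(recovered, local_kv))
```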
Ring Anything Attention cannot be combined with Unified Attention. Set `ring_degree > 1` and `ring_anything=True` to enable it.
```py
pipeline.transformer.enable_parallelism(config=ContextParallelConfig(ring_degree=2, ring_anything=True))
```
> [!TIP]
> Add the `gloo` backend to [init_process_group](https://docs.pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group) to avoid multiple forced CUDA syncs from H2D and D2H transfers.

```py
import torch.distributed as dist

dist.init_process_group(backend="cpu:gloo,cuda:nccl")
```
> [!NOTE]
> Ring Anything Attention currently only supports inference, and attention masks are not supported (`attn_mask` must be `None`).
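Putting the pieces together, an end-to-end inference run might look as follows. This is a minimal sketch: the prompt, resolution, output handling, and `torchrun` launch are illustrative assumptions, while `ContextParallelConfig` and `enable_parallelism` come from the snippet above.

```py
import os

import torch
import torch.distributed as dist
from diffusers import ContextParallelConfig, FluxPipeline

# Launch with: torchrun --nproc-per-node=2 demo.py
dist.init_process_group(backend="cpu:gloo,cuda:nccl")
device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))
torch.cuda.set_device(device)

pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to(device)
pipeline.transformer.enable_parallelism(
    config=ContextParallelConfig(ring_degree=2, ring_anything=True)
)

# 1008x1008 yields a token count that is not evenly divisible across ranks,
# exercising the ring_anything path (see the benchmark table below).
image = pipeline("a photo of a cat", height=1008, width=1008).images[0]
if dist.get_rank() == 0:
    image.save("output.png")
dist.destroy_process_group()
```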
See the FLUX.1-dev benchmarks below on a node of 4 RTX 4090 (48GB) GPUs.

| CP Backend        | Time / Iter (ms) | Steps / Sec | Peak Memory (GB) | Shape (HxW) |
|-------------------|------------------|-------------|------------------|-------------|
| ulysses           | 259.07           | 3.86        | 33.83            | 1024x1024   |
| ring              | 338.98           | 2.95        | 33.83            | 1024x1024   |
| unified_balanced  | 321.54           | 3.11        | 33.83            | 1024x1024   |
| ulysses_anything  | 259.07           | 3.86        | 33.83            | 1024x1024   |
| ring_anything     | 340.14           | 2.94        | 33.83            | 1024x1024   |
| ulysses           | failed           | failed      | failed           | 1008x1008   |
| ring              | failed           | failed      | failed           | 1008x1008   |
| unified_balanced  | failed           | failed      | failed           | 1008x1008   |
| ulysses_anything  | 253.16           | 3.95        | 33.75            | 1008x1008   |
| ring_anything     | 335.57           | 2.98        | 33.75            | 1008x1008   |
From the above table, Ring Anything Attention offers compatibility with arbitrary sequence lengths while maintaining performance comparable to the standard Ring Attention.
For more details on the motivation and trade-offs of Ring Anything Attention, see [this comment](https://github.com/huggingface/diffusers/pull/13545#issuecomment-4304104462).
### parallel_config

Pass `parallel_config` during model initialization to enable context parallelism.

src/diffusers/hooks/context_parallel.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -210,7 +210,7 @@ def _prepare_cp_input(self, x: torch.Tensor, cp_input: ContextParallelInput) ->
             )
             return x
         else:
-            if self.parallel_config.ulysses_anything:
+            if self.parallel_config.ulysses_anything or self.parallel_config.ring_anything:
                 return PartitionAnythingSharder.shard_anything(
                     x, cp_input.split_dim, self.parallel_config._flattened_mesh
                 )
@@ -239,7 +239,7 @@ def post_forward(self, module, output):
         for i, cpm in enumerate(self.metadata):
             if cpm is None:
                 continue
-            if self.parallel_config.ulysses_anything:
+            if self.parallel_config.ulysses_anything or self.parallel_config.ring_anything:
                 output[i] = PartitionAnythingSharder.unshard_anything(
                     output[i], cpm.gather_dim, self.parallel_config._flattened_mesh
                 )
```
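`PartitionAnythingSharder` itself is not part of this diff; only the gating changes, so the ring_anything path reuses the same uneven shard/unshard helpers as ulysses_anything. As a rough mental model (an assumption, not the sharder's actual code), an uneven split and its inverse can be expressed with `torch.tensor_split`, which hands out near-equal chunks without padding:

```py
import torch

world_size = 4
x = torch.randn(1, 1009, 3072)  # (B, S, D); 1009 is not divisible by 4

# tensor_split yields near-equal per-rank chunks: [253, 252, 252, 252].
shards = torch.tensor_split(x, world_size, dim=1)
print([s.shape[1] for s in shards])

# Unsharding is a concatenation in rank order.
assert torch.equal(torch.cat(shards, dim=1), x)
```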

src/diffusers/models/_modeling_parallel.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,9 @@ class ContextParallelConfig:
6464
Whether to enable "Ulysses Anything" mode, which supports arbitrary sequence lengths and head counts that
6565
are not evenly divisible by `ulysses_degree`. When enabled, `ulysses_degree` must be greater than 1 and
6666
`ring_degree` must be 1.
67+
ring_anything (`bool`, *optional*, defaults to `False`):
68+
Whether to enable "Ring Anything" mode, which supports arbitrary sequence lengths. When enabled,
69+
`ring_degree` must be greater than 1 and `ulysses_degree` must be 1.
6770
mesh (`torch.distributed.device_mesh.DeviceMesh`, *optional*):
6871
A custom device mesh to use for context parallelism. If provided, this mesh will be used instead of
6972
creating a new one. This is useful when combining context parallelism with other parallelism strategies
@@ -82,6 +85,8 @@ class ContextParallelConfig:
8285
# Whether to enable ulysses anything attention to support
8386
# any sequence lengths and any head numbers.
8487
ulysses_anything: bool = False
88+
# Whether to enable ring anything attention to support any sequence lengths.
89+
ring_anything: bool = False
8590

8691
_rank: int = None
8792
_world_size: int = None
@@ -114,6 +119,13 @@ def __post_init__(self):
114119
raise ValueError("ulysses_degree must be greater than 1 for ulysses_anything to be enabled.")
115120
if self.ring_degree > 1:
116121
raise ValueError("ulysses_anything cannot be enabled when ring_degree > 1.")
122+
if self.ring_anything:
123+
if self.ring_degree == 1:
124+
raise ValueError("ring_degree must be greater than 1 for ring_anything to be enabled.")
125+
if self.ulysses_degree > 1:
126+
raise ValueError("ring_anything cannot be enabled when ulysses_degree > 1.")
127+
if self.ulysses_anything and self.ring_anything:
128+
raise ValueError("ulysses_anything and ring_anything cannot both be enabled.")
117129

118130
@property
119131
def mesh_shape(self) -> tuple[int, int]:
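A quick way to see the new `__post_init__` guards in action (a sketch, not part of the diff; assumes `ContextParallelConfig` is importable from the top-level `diffusers` namespace):

```py
from diffusers import ContextParallelConfig

# Valid: ring_anything needs ring_degree > 1 and ulysses_degree == 1.
ContextParallelConfig(ring_degree=2, ring_anything=True)

try:
    ContextParallelConfig(ring_degree=1, ring_anything=True)
except ValueError as e:
    print(e)  # ring_degree must be greater than 1 for ring_anything to be enabled.

try:
    ContextParallelConfig(ring_degree=2, ulysses_degree=2, ring_anything=True)
except ValueError as e:
    print(e)  # ring_anything cannot be enabled when ulysses_degree > 1.
```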

src/diffusers/models/attention_dispatch.py

Lines changed: 143 additions & 14 deletions
```diff
@@ -2079,6 +2079,119 @@ def backward(
         return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None, None
 
 
+class TemplatedRingAnythingAttention(torch.autograd.Function):
+    @staticmethod
+    def forward(
+        ctx: torch.autograd.function.FunctionCtx,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        attn_mask: torch.Tensor | None,
+        dropout_p: float,
+        is_causal: bool,
+        scale: float | None,
+        enable_gqa: bool,
+        return_lse: bool,
+        forward_op,
+        backward_op,
+        _parallel_config: "ParallelConfig" | None = None,
+    ):
+        # Ring attention for arbitrary sequence lengths.
+        if attn_mask is not None:
+            raise ValueError(
+                "TemplatedRingAnythingAttention does not support non-None attn_mask: "
+                "non-uniform sequence lengths across ranks make cross-rank mask slicing ambiguous."
+            )
+        ring_mesh = _parallel_config.context_parallel_config._ring_mesh
+        group = ring_mesh.get_group()
+        rank = _parallel_config.context_parallel_config._ring_local_rank
+        world_size = _parallel_config.context_parallel_config.ring_degree
+        next_rank = (rank + 1) % world_size
+        prev_out = prev_lse = None
+
+        ctx.forward_op = forward_op
+        ctx.backward_op = backward_op
+        ctx.q_shape = query.shape
+        ctx.kv_shape = key.shape
+        ctx._parallel_config = _parallel_config
+
+        kv_seq_len = key.shape[1]  # local S_KV (may differ across ranks)
+        all_kv_seq_lens = gather_size_by_comm(kv_seq_len, group)
+        s_max = max(all_kv_seq_lens)
+
+        # Padding is applied on the sequence dimension (dim=1) at the end.
+        def pad_to_s_max(t: torch.Tensor) -> torch.Tensor:
+            pad_len = s_max - t.shape[1]
+            if pad_len == 0:
+                return t
+            pad_shape = (t.shape[0], pad_len, *t.shape[2:])
+            return torch.cat([t, t.new_zeros(pad_shape)], dim=1)
+
+        # Pad each local KV to the maximum local sequence length so all ranks can all-gather same-sized buffers.
+        key_padded = pad_to_s_max(key)
+        value_padded = pad_to_s_max(value)
+
+        kv_buffer = torch.cat([key_padded.flatten(), value_padded.flatten()]).contiguous()
+        kv_buffer = funcol.all_gather_tensor(kv_buffer, gather_dim=0, group=group)
+        kv_buffer = kv_buffer.chunk(world_size)
+
+        # numel per-rank in the padded layout
+        kv_padded_numel = key_padded.numel()
+
+        for i in range(world_size):
+            if i > 0:
+                true_seq_len = all_kv_seq_lens[next_rank]
+                kv = kv_buffer[next_rank]
+                # Reshape to padded shape, then slice to true sequence length
+                key = kv[:kv_padded_numel].reshape_as(key_padded)[:, :true_seq_len]
+                value = kv[kv_padded_numel:].reshape_as(value_padded)[:, :true_seq_len]
+                next_rank = (next_rank + 1) % world_size
+            else:
+                # i == 0: use local (unpadded) key/value
+                key = key_padded[:, :kv_seq_len]
+                value = value_padded[:, :kv_seq_len]
+
+            out, lse = forward_op(
+                ctx,
+                query,
+                key,
+                value,
+                attn_mask,
+                dropout_p,
+                is_causal,
+                scale,
+                enable_gqa,
+                True,
+                _save_ctx=i == 0,
+                _parallel_config=_parallel_config,
+            )
+
+            if _parallel_config.context_parallel_config.convert_to_fp32:
+                out = out.to(torch.float32)
+                lse = lse.to(torch.float32)
+
+            if is_torch_version("<", "2.9.0"):
+                lse = lse.unsqueeze(-1)
+            if prev_out is not None:
+                out = prev_out - torch.nn.functional.sigmoid(lse - prev_lse) * (prev_out - out)
+                lse = prev_lse - torch.nn.functional.logsigmoid(prev_lse - lse)
+            prev_out = out
+            prev_lse = lse
+
+        out = out.to(query.dtype)
+        lse = lse.squeeze(-1)
+
+        return (out, lse) if return_lse else out
+
+    @staticmethod
+    def backward(
+        ctx: torch.autograd.function.FunctionCtx,
+        grad_out: torch.Tensor,
+        *args,
+    ):
+        raise NotImplementedError("Backward pass for Ring Anything Attention in diffusers is not implemented yet.")
+
+
 class TemplatedUlyssesAnythingAttention(torch.autograd.Function):
     @staticmethod
     def forward(
```
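The running `(out, lse)` merge in the forward loop above is the standard log-sum-exp combination of per-block attention outputs. A small single-device check (illustrative, using the usual SDPA `(B, H, S, D)` layout rather than the PR's) confirms that merging two KV blocks this way matches attention over the concatenated KV:

```py
import torch
import torch.nn.functional as F

torch.manual_seed(0)
q = torch.randn(1, 8, 16, 64)  # (B, H, S_q, D)
k1, v1 = torch.randn(1, 8, 10, 64), torch.randn(1, 8, 10, 64)  # KV block 1
k2, v2 = torch.randn(1, 8, 7, 64), torch.randn(1, 8, 7, 64)    # KV block 2 (different length)

def attn(q, k, v):
    scores = q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5
    lse = torch.logsumexp(scores, dim=-1, keepdim=True)  # (B, H, S_q, 1)
    return torch.softmax(scores, dim=-1) @ v, lse

(out1, lse1), (out2, lse2) = attn(q, k1, v1), attn(q, k2, v2)

# The merge used in the ring loop:
out = out1 - torch.sigmoid(lse2 - lse1) * (out1 - out2)
lse = lse1 - F.logsigmoid(lse1 - lse2)

ref, ref_lse = attn(q, torch.cat([k1, k2], dim=2), torch.cat([v1, v2], dim=2))
assert torch.allclose(out, ref, atol=1e-5)
assert torch.allclose(lse, ref_lse, atol=1e-5)
```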
```diff
@@ -2258,20 +2371,36 @@ def _templated_context_parallel_attention(
             _parallel_config,
         )
     elif _parallel_config.context_parallel_config.ring_degree > 1:
-        return TemplatedRingAttention.apply(
-            query,
-            key,
-            value,
-            attn_mask,
-            dropout_p,
-            is_causal,
-            scale,
-            enable_gqa,
-            return_lse,
-            forward_op,
-            backward_op,
-            _parallel_config,
-        )
+        if _parallel_config.context_parallel_config.ring_anything:
+            return TemplatedRingAnythingAttention.apply(
+                query,
+                key,
+                value,
+                attn_mask,
+                dropout_p,
+                is_causal,
+                scale,
+                enable_gqa,
+                return_lse,
+                forward_op,
+                backward_op,
+                _parallel_config,
+            )
+        else:
+            return TemplatedRingAttention.apply(
+                query,
+                key,
+                value,
+                attn_mask,
+                dropout_p,
+                is_causal,
+                scale,
+                enable_gqa,
+                return_lse,
+                forward_op,
+                backward_op,
+                _parallel_config,
+            )
     elif _parallel_config.context_parallel_config.ulysses_degree > 1:
         if _parallel_config.context_parallel_config.ulysses_anything:
             # For Any sequence lengths and Any head num support
```
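`gather_size_by_comm`, used in the new class to exchange per-rank KV lengths, is defined elsewhere in `attention_dispatch.py` and is not shown in this diff. A minimal stand-in (an assumption about its contract, not its actual implementation) could look like:

```py
import torch.distributed as dist

def gather_size_by_comm(size: int, group) -> list[int]:
    # Hypothetical sketch: every rank contributes its local KV sequence length
    # and receives all lengths in rank order. all_gather_object works on any
    # backend, including the gloo backend the docs recommend for this tiny,
    # CPU-side transfer (avoiding forced CUDA syncs).
    sizes = [None] * dist.get_world_size(group=group)
    dist.all_gather_object(sizes, size, group=group)
    return sizes
```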
