Commit 1a8a17b
Add ACE-Step pipeline for text-to-music generation (#13095)
* Add ACE-Step pipeline for text-to-music generation
Rebased on origin/main from the original pr-13095 branch (3 commits squashed).
- AceStepDiTModel: Diffusion Transformer with RoPE, GQA, sliding window,
AdaLN timestep conditioning, and cross-attention.
- AceStepConditionEncoder: fuses text / lyric / timbre into a single
cross-attention sequence.
- AceStepPipeline: text2music / cover / repaint / extract / lego / complete.
- Conversion script for the original checkpoint layout.
- Docs + tests.
* Fix ACE-Step pipeline audio quality and auto-detect turbo/base/sft variants
The PR's original inference produced low-quality audio on turbo because the
pipeline (a) mangled the SFT prompt format, (b) applied classifier-free guidance
with the wrong unconditional embedding (empty-string encoded vs. the learned
`null_condition_emb`), and (c) hardcoded turbo defaults even when loading a
base/SFT checkpoint.
Changes:
* Converter preserves `null_condition_emb` (stored under the condition encoder)
and propagates `is_turbo`/`model_version` into the transformer config so the
pipeline can route per-variant defaults.
* `AceStepConditionEncoder` registers `null_condition_emb` as a learned
parameter matching the original module.
* Pipeline auto-detects variant via `is_turbo`/`model_version` and picks
defaults that match `acestep/inference.py`:
* turbo: steps=8, shift=3.0, guidance_scale=1.0 (no CFG)
* base/SFT: steps=27, shift=1.0, guidance_scale=7.0
* Base/SFT timestep schedule uses the linear+shift transform from
`acestep/models/base/modeling_acestep_v15_base.py`; turbo still uses the
hardcoded 8-step `SHIFT_TIMESTEPS` table.
* CFG reuses the learned `null_condition_emb` and batches the
conditional+unconditional forwards into a single transformer call.
* `SFT_GEN_PROMPT` matches the newline layout in `acestep/constants.py` so the
text encoder sees the same prompt distribution it was trained on.
DiT parity vs. the original ACE-Step 1.5 turbo DiT is bit-identical
(max_abs=0.0 in fp32 eager/SDPA across 4 seed/shape cases) — see
scripts/dit_parity_test.py.
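A minimal sketch of the CFG batching described above, reusing the learned `null_condition_emb` for the unconditional branch. Variable names, shapes, and the transformer call signature are illustrative (the extra DiT inputs such as `timestep_r` and `context_latents` are omitted), not the exact pipeline code:

```python
import torch

def cfg_noise_pred(transformer, latents, timestep, cond_emb, null_condition_emb, guidance_scale):
    # Unconditional branch uses the learned null embedding, not an empty-string text encoding.
    uncond_emb = null_condition_emb.expand_as(cond_emb)
    encoder_in = torch.cat([uncond_emb, cond_emb], dim=0)
    latents_in = torch.cat([latents, latents], dim=0)
    # Single batched forward for the conditional + unconditional halves.
    noise = transformer(latents_in, timestep.repeat(2), encoder_hidden_states=encoder_in)
    noise_uncond, noise_cond = noise.chunk(2, dim=0)
    return noise_uncond + guidance_scale * (noise_cond - noise_uncond)
```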
* Add ACE-Step parity test scripts
Two developer-facing parity harnesses live under scripts/:
* dit_parity_test.py — loads the same converted turbo weights into the
original AceStepDiTModel and the diffusers AceStepDiTModel, drives
identical (hidden_states, timestep, timestep_r, encoder_hidden_states,
context_latents) inputs, and asserts max-abs-diff ≤ 1e-5 in fp32
eager/SDPA. Currently passes bit-identical (max_abs=0) across four
shape/seed cases including batched + odd-length paths.
* audio_parity_jieyue.py — full end-to-end audio parity. Given the same
JSON example, runs both the original ACE-Step 1.5 pipeline and the
diffusers AceStepPipeline at matched seed/precision (bf16 + FA2 by
default) and saves side-by-side .wav files for listening verification.
Supports text2music / cover / repaint × turbo / base / sft via a
--matrix mode that writes 18 wavs named
{variant}_{task}_{official,diffusers}.wav.
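For reference, a hedged sketch of the kind of waveform check implied by the listening comparison — the file names follow the `--matrix` naming convention above, but the metric code itself is illustrative, not part of the harness:

```python
import numpy as np
from scipy.io import wavfile

# Load one official/diffusers pair produced by --matrix mode (names assumed).
sr_a, a = wavfile.read("acestep-v15-turbo_text2music_official.wav")
sr_b, b = wavfile.read("acestep-v15-turbo_text2music_diffusers.wav")
assert sr_a == sr_b, "sample rates must match for a fair comparison"

n = min(len(a), len(b))
a = a[:n].astype(np.float64).ravel()
b = b[:n].astype(np.float64).ravel()
print(f"waveform Pearson = {np.corrcoef(a, b)[0, 1]:.6f}")
```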
* Route SFT parity to acestep-v15-sft checkpoint
On jieyue the release tree has a dedicated SFT checkpoint at
checkpoints/acestep-v15-sft with its own modeling_acestep_v15_base.py
shipped under acestep/models/sft/. Point the SFT row of the parity matrix
at that checkpoint / module so we're testing the actual SFT weights, not
the plain base ones.
* audio_parity_jieyue: fix doubled 'acestep-' in cache path; --converted-root flag
Previously the converted-pipeline cache dir was
`/tmp/acestep-<variant>-diffusers` but <variant> already starts with
"acestep-", giving `/tmp/acestep-acestep-v15-turbo-diffusers`. Drop the
prefix.
On jieyue the overlay rootfs (including /tmp) only has a few GB free; a
full turbo conversion needs ~5 GB per variant. Add --converted-root (env
ACESTEP_CONVERTED_ROOT) so the cache can live on vepfs.
* audio_parity_jieyue: two-phase matrix bootstraps cover/repaint from text2music
The ACE-Step release bundle on jieyue doesn't ship sample .wav/.mp3
files, so matrix mode had no default --src-audio and would skip
cover/repaint entirely. Run text2music first for every variant, then
reuse the TURBO official text2music output as the shared source for the
cover/repaint rows. Users can still override with --src-audio.
* audio_parity_jieyue: seed the diffusers generator on the pipeline device
The ORIGINAL ACE-Step pipeline seeds on the execution device
(`torch.Generator(device=device).manual_seed(seed)`), i.e. the CUDA RNG
stream when running on GPU. Previously the parity harness seeded the
diffusers side with a CPU generator, so even though the seed integer
matched, the two sides drew different noise from the outset and the
final outputs were essentially uncorrelated. Use the execution-device
generator on both sides for a fair comparison.
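Sketch of the fix — seed on the execution device so both pipelines draw from the same RNG stream (values illustrative):

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
seed = 42
# Matches the original pipeline's seeding; pass this generator to both __call__ paths.
generator = torch.Generator(device=device).manual_seed(seed)
```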
* Fix ACE-Step pipeline: switch to APG guidance + peak normalization
Two issues found after the first jieyue audio parity run:
1. The original base/SFT pipeline uses APG (Adaptive Projected Guidance,
acestep/models/common/apg_guidance.py) with a stateful momentum
buffer and norm/projection steps — NOT vanilla CFG. Using vanilla CFG
produced uncorrelated outputs vs. the reference (pearson ~0.0 on
20 s samples); this PR ports `_apg_forward` + `_APGMomentumBuffer`
and plugs them into the denoising loop when `guidance_scale > 1`.
Momentum is instantiated once per pipeline call (persists across
denoising steps) to match the reference semantics.
2. The post-VAE "anti-clipping normalization" in this pipeline was
`audio /= std * 5` with a `std<1 -> std=1` guard. The original
post-processing in
acestep/core/generation/handler/generate_music_decode.py is simple
peak normalization: `if audio.abs().max() > 1: audio /= peak`. The
std-based proxy both (a) let clips with peak < 1 leak through
unchanged (over-quiet) and (b) failed to bring clipping peaks to
exactly 1 in a bunch of base/SFT cases (observed max=1.000, std=0.200
repeatedly in the first parity run). Switch to peak normalization on
both sides.
Tested via scripts/audio_parity_jieyue.py on A800; re-run pending to
confirm the base/SFT correlation improvements.
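For orientation, a hedged sketch of the general APG recipe (momentum buffer, norm rescale, projection onto the conditional prediction). It follows the published Adaptive Projected Guidance formulation; the ported `_apg_forward` may differ in reduction dims and default hyperparameters, which are illustrative here:

```python
import torch
import torch.nn.functional as F

class MomentumBuffer:
    def __init__(self, momentum: float = -0.75):
        self.momentum = momentum
        self.running_average = 0.0

    def update(self, value: torch.Tensor) -> None:
        # Stateful: instantiated once per pipeline call, persists across denoising steps.
        self.running_average = value + self.momentum * self.running_average

def apg_guidance(pred_cond, pred_uncond, buffer, guidance_scale=7.0, eta=0.0, norm_threshold=2.5):
    diff = pred_cond - pred_uncond
    buffer.update(diff)
    diff = buffer.running_average
    if norm_threshold > 0:
        diff_norm = diff.norm(p=2, dim=-1, keepdim=True)
        diff = diff * torch.minimum(torch.ones_like(diff_norm), norm_threshold / diff_norm)
    v1 = F.normalize(pred_cond, dim=-1)
    parallel = (diff * v1).sum(dim=-1, keepdim=True) * v1
    orthogonal = diff - parallel
    return pred_cond + (guidance_scale - 1.0) * (orthogonal + eta * parallel)
```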
* Fix ACE-Step chunk mask values to match the original pipeline
The DiT receives `context_latents = concat(src_latents, chunk_mask)` on the
channel dim, and was trained with chunk_mask values drawn from the three
sentinels documented in acestep/inference.py:
2.0 -> model-decided (default for text2music / cover / full-generation)
1.0 -> keep this latent frame from src_latents (repaint preserved region)
0.0 -> explicitly repaint this frame (only inside the repaint window)
Previously _build_chunk_mask returned all-1.0 for text2music (and cover /
lego), and an inverted 0/1 mask for repaint (1 inside the window, 0 outside).
Either case puts context_latents out of distribution. Switch text2music /
cover to the 2.0 sentinel and flip the repaint mask so it's 1.0 outside /
0.0 inside. Update the repaint src_latents zero-out to multiply by the new
mask (was `1 - chunk_mask`) so the zero region still lines up with the
repaint window.
* Add direct invoker for ACE-Step generate_music (ground truth)
Our earlier audio_parity_jieyue.py reconstructs the original pipeline by
calling AceStepConditionGenerationModel.generate_audio() directly, which
silently skips a lot of the real handler plumbing (conditioning masks,
silence-latent tiling, cover/repaint pre-processing, etc.). That made the
'official' wavs we saved sound wrong — flat, drone-like, not music.
This new script calls acestep.inference.generate_music end-to-end through
the real AceStepHandler, with LM + CoT explicitly disabled so we still have
a deterministic comparison. Use it to generate the ground-truth 'official'
wav for a given JSON example, then separately run the diffusers pipeline
with the same inputs and diff the two.
* run_official_generate_music: call initialize_service to bind a DiT variant
AceStepHandler() is a shell — you have to call handler.initialize_service(
project_root=..., config_path=..., device=..., use_flash_attention=..., ...)
before generate_music will work. Mirror what cli.py does at the equivalent
spot (around cli.py:1400).
* Fix silence-reference for ACE-Step timbre encoder
The root cause for the flat / drone-like outputs I was seeing (including
in my 'official' reconstruction): when no reference_audio is provided the
pipeline was feeding literal zeros to the timbre encoder. The real
handler feeds a slice of the learned `silence_latent` tensor.
The handler also transposes silence_latent on load (see
acestep/core/generation/handler/init_service_loader.py:214:
self.silence_latent = torch.load(...).transpose(1, 2)
) converting [1, 64, 15000] -> [1, 15000, 64] so that
`silence_latent[:, :750, :]` yields the expected [1, 750, 64] shape.
Changes:
* Converter: load silence_latent.pt, transpose to [1, T, C], bake into
the condition_encoder safetensors under key `silence_latent`.
(Also keeps the raw .pt file at the pipeline root for debugging.)
* AceStepConditionEncoder: register `silence_latent` as a persistent
buffer so from_pretrained loads it alongside the trained weights.
* Pipeline: when reference_audio is None, slice
`condition_encoder.silence_latent[:, :timbre_fix_frame, :]` and
broadcast across the batch instead of zeros. Emits a loud warning
(and falls back to zeros) if the buffer is all-zero — that means the
checkpoint was produced by an older converter and should be rebuilt.
* audio_parity_jieyue.py: the reference path now matches the handler's
silence-latent slicing.
Without this fix, every variant/task combo produced drone-like audio
even when my numeric DiT-forward parity claimed they were identical.
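A minimal sketch of the silence-latent plumbing described above. The real buffer comes from `silence_latent.pt` ([1, 64, 15000], transposed to [1, 15000, 64] on load); a random stand-in is used here so the snippet runs on its own, and shapes/names are illustrative:

```python
import torch

# Stand-in for torch.load("silence_latent.pt"); real tensor is [1, 64, 15000].
silence_latent = torch.randn(1, 64, 15000).transpose(1, 2)   # -> [1, 15000, 64]

# Condition-encoder side (conceptually): registered as a persistent buffer so
# from_pretrained restores it alongside the trained weights:
#   self.register_buffer("silence_latent", silence_latent, persistent=True)

# Pipeline side, when reference_audio is None: slice and broadcast instead of zeros.
batch_size, timbre_fix_frame = 2, 750
timbre_ref = silence_latent[:, :timbre_fix_frame, :].expand(batch_size, -1, -1)  # [2, 750, 64]
```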
* Fix three more ACE-Step pipeline bugs I found by dumping real inputs
Instrumented the live generate_audio call in the real ACE-Step handler and
observed the exact tensors it sees — my diffusers pipeline was wrong in
three independent ways:
1. src_latents for text2music should be silence_latent tiled to
latent_length, NOT zeros. The handler fills no-target cases from
silence_latent_tiled (observed std=0.96). Zeros are OOD for the DiT
context_latents concat and produce drone-like outputs.
2. chunk_mask values cap at 1.0 (not 2.0). The handler starts with a
bool tensor (True inside the generate span, False outside); the
chunk_mask_modes=auto -> 2.0 override does NOT take effect because
the underlying tensor is bool, so setting entry = 2.0 casts to True.
After the later .to(dtype) float cast, the DiT sees 1.0/0.0 — exactly
what I observed in the captured tensor (unique values = [True]).
3. Default shift is 1.0 for ALL variants, including turbo. I was
defaulting turbo to shift=3.0 which picks a different SHIFT_TIMESTEPS
table (the 8-step schedule is keyed by shift, not variant).
Also:
* Added _silence_latent_tiled() helper that slices / tiles the learned
silence_latent (now loaded as a buffer on the condition encoder) to
the requested latent length.
* Repaint path now substitutes silence_latent (not raw zeros) inside
the repaint window — matches conditioning_masks.py.
* audio_parity_jieyue.py mirrors the same src/chunk/shift choices on
its 'original' leg for apples-to-apples parity once the buggy
reconstruction is removed from the picture.
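A hedged sketch of the corrected text2music conditioning from items 1–2 above; the helper name follows the commit text, everything else (shapes, lengths) is illustrative:

```python
import math
import torch

def _silence_latent_tiled(silence_latent: torch.Tensor, latent_length: int) -> torch.Tensor:
    # Repeat the learned silence latent along time until it covers latent_length, then crop.
    reps = math.ceil(latent_length / silence_latent.shape[1])
    return silence_latent.repeat(1, reps, 1)[:, :latent_length, :]

silence_latent = torch.randn(1, 15000, 64)     # stand-in for the learned buffer
latent_length = 4000                           # e.g. 160 s at 25 latents/s

# text2music: src_latents comes from tiled silence, NOT zeros.
src_latents = _silence_latent_tiled(silence_latent, latent_length)

# chunk_mask starts as bool and becomes 1.0/0.0 after the dtype cast; the 2.0 sentinel never survives.
chunk_mask = torch.ones(1, latent_length, 1, dtype=torch.bool).to(src_latents.dtype)
context_latents = torch.cat([src_latents, chunk_mask], dim=-1)   # concat on the channel dim
```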
* Add peak+loudness post-normalization to AceStepPipeline
The real pipeline normalizes audio in two stages (see
acestep/audio_utils.py:72 normalize_audio + generate_music_decode.py):
1. if peak > 1: audio /= peak (anti-clip)
2. audio *= target_amp / peak (target_amp = 10 ** (-1/20) ~ 0.891)
Step 2 is loudness normalization to -1 dBFS. Without it diffusers outputs
had peak=1.0 vs the real 0.891 — same music content (pearson was ~0.86
already), just 1.12x louder. Add step 2 after the existing anti-clip step.
* Match acestep/inference.py inference_steps=8 for ALL variants
GenerationParams.inference_steps default is 8 — turbo AND base/SFT. I had
base/SFT defaulting to 27 here, so every base/SFT parity run was comparing
a 27-step diffusers trajectory against an 8-step real trajectory. Different
number of denoising steps means different audio even at fixed seed.
This likely explains the lower base/SFT correlation in my earlier jieyue
runs (turbo was 0.86, base/SFT were 0.32-0.34). Aligning step counts
should bring base/SFT closer to turbo parity.
* Address PR #13095 review: rename classes + reuse diffusers primitives
Response to dg845's PR comments batch 1+2. DiT parity harness still bit-identical
(max_abs=0 on fp32 / SDPA across 4 shape cases).
Transformer file:
* Rename AceStepDiTModel -> AceStepTransformer1DModel (alias kept).
* Rename AceStepDiTLayer -> AceStepTransformerBlock (alias kept).
* Inherit AttentionMixin + CacheMixin on the DiT model.
* Swap in diffusers.models.normalization.RMSNorm for the hand-rolled
AceStepRMSNorm (weight-key-compatible).
* Swap the hand-rolled rotary embedding + apply_rotary for diffusers'
get_1d_rotary_pos_embed + apply_rotary_emb (use_real_unbind_dim=-2 to
match the cat-half convention ACE-Step inherits from Qwen3).
* Use get_timestep_embedding with flip_sin_to_cos=True — keeps the
(cos, sin) ordering of the original sinusoidal. State-dict-compatible.
* Drop max_position_embeddings arg from DiT config (RoPE computes freqs
per call based on seq_len); converter drops it.
* Gradient-checkpoint call now takes just the layer module (matches the
Flux2 idiom).
Pipeline modeling file (pipelines/ace_step/modeling_ace_step.py):
* Moved _pack_sequences + AceStepEncoderLayer here — they aren't used
by the DiT, so they shouldn't live in the transformer file.
* AceStepLyricEncoder + AceStepTimbreEncoder set
_supports_gradient_checkpointing = True and wrap encoder-layer calls
through the checkpointing func when enabled.
* Use diffusers RMSNorm + the RoPE helper from the transformer file
(shared single implementation).
Converter (scripts/convert_ace_step_to_diffusers.py):
* model_index.json now carries AceStepTransformer1DModel.
* Drop max_position_embeddings / use_sliding_window from the emitted
configs.
No numerical regressions: scripts/dit_parity_test.py PASSES with
max_abs=0.0 on fp32/SDPA across short, long, batched, and
padding-path shape variants.
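For the RoPE swap, a hedged sketch of how the diffusers helpers pair up for the cat-half convention; the `repeat_interleave_real=False` / `use_real_unbind_dim=-2` pairing is my reading of that convention, and the dims are illustrative, not the actual DiT config:

```python
import torch
from diffusers.models.embeddings import get_1d_rotary_pos_embed, apply_rotary_emb

head_dim, seq_len = 64, 128
cos, sin = get_1d_rotary_pos_embed(head_dim, seq_len, use_real=True, repeat_interleave_real=False)

q = torch.randn(2, 16, seq_len, head_dim)   # [batch, heads, seq, head_dim]
# use_real_unbind_dim=-2 selects the cat-half rotation (Qwen3-style first-half/second-half split).
q = apply_rotary_emb(q, (cos, sin), use_real=True, use_real_unbind_dim=-2)
```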
* Address PR #13095 review: pipeline polish + converter HF-hub support
Response to dg845 review comments on the pipeline side. DiT parity still
bit-identical (max_abs=0 across 4 shape cases).
Pipeline (pipelines/ace_step/pipeline_ace_step.py):
* Add `sample_rate` + `latents_per_second` properties sourced from the
VAE config so the pipeline no longer hardcodes 48000 / 25 / 1920.
Propagates through prepare_latents, chunk_mask window math, and the
audio-duration round-trip.
* Add `do_classifier_free_guidance` property (matches LTX2 et al.).
* Add `check_inputs(...)` called from `__call__` before allocating noise.
Validates prompt type, lyrics type, task_type, step count, guidance
scale, shift, cfg interval bounds and repaint window ordering.
* Add `callback_on_step_end` + `callback_on_step_end_tensor_inputs` —
the modern callback form. The legacy `callback` / `callback_steps`
pair is kept for back-compat. Setting `pipe._interrupt = True` inside
the callback stops the loop early.
* Expose `encode_audio(audio)` as a public helper that wraps the tiled
VAE encode + (B, T, D) transpose the pipeline performs internally.
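A hedged usage sketch of the new callback form from the bullets above; the model id comes from the docs update later in this PR, and the prompt / early-stop step count are arbitrary:

```python
from diffusers import AceStepPipeline

pipe = AceStepPipeline.from_pretrained("ACE-Step/Ace-Step1.5")

def stop_after_four_steps(pipe, step, timestep, callback_kwargs):
    if step >= 4:
        pipe._interrupt = True        # stops the denoising loop early
    return callback_kwargs

out = pipe(
    prompt="dreamy synth pop, female vocals",
    callback_on_step_end=stop_after_four_steps,
    callback_on_step_end_tensor_inputs=["latents"],
)
```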
Converter (scripts/convert_ace_step_to_diffusers.py):
* Accept a Hugging Face Hub repo id for `--checkpoint_dir`; resolves it
via `huggingface_hub.snapshot_download` when the argument isn't a
local path.
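Sketch of the resolution logic (paths/repo id illustrative): treat a non-local `--checkpoint_dir` argument as a Hub repo id.

```python
import os
from huggingface_hub import snapshot_download

checkpoint_dir = "ACE-Step/Ace-Step1.5"   # or a local directory path
if not os.path.isdir(checkpoint_dir):
    checkpoint_dir = snapshot_download(repo_id=checkpoint_dir)
```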
Exports:
* Register `AceStepTransformer1DModel` in the top-level __init__,
models/__init__, models/transformers/__init__, and dummy_pt_objects so
`from diffusers import AceStepTransformer1DModel` works and the
pipeline loader resolves the new class name from model_index.json.
Deferred for a follow-up (commented inline in the PR): full
`Attention + AttnProcessor + dispatch_attention_fn` refactor and
`FlowMatchEulerDiscreteScheduler` migration — both would benefit from a
dedicated parity re-run and review.
* Fix stale ACE-Step 1.0-era docs / class names in the 1.5 integration
Docs and docstrings still carried a mix of the 1.0 paper title, a non-existent
`ACE-Step/ACE-Step-v1-5-turbo` hub id, the `shift=3.0` turbo default, and
the old `AceStepDiTModel` class name. Cleaned up to match the actual
1.5 release:
* pipelines/ace_step.md: correct citation title ("ACE-Step 1.5: Pushing
the Boundaries of Open-Source Music Generation"), correct repo
(`ace-step/ACE-Step-1.5`), new variants table with real HF ids
(`Ace-Step1.5` / `acestep-v15-base` / `acestep-v15-sft`) and their
per-variant step/CFG defaults, drop the wrong `shift=3.0` tip.
* models/ace_step_transformer.md: page renamed to
`AceStepTransformer1DModel` with a short 1.5-specific description;
`AceStepDiTModel` noted as a backwards-compat alias.
* pipeline_ace_step.py: import, docstring, `Args`, and `__init__`
annotation reference `AceStepTransformer1DModel`; example model id
now `ACE-Step/Ace-Step1.5`; `_variant_defaults` docstring and the
`__call__` variant-fallback comment no longer claim `shift=3.0` /
`27 steps` — real defaults are 8 steps / shift=1.0 across all
variants, guidance=1.0 (turbo) vs 7.0 (base+sft).
* Address PR #13095 review: VAE tiling on AutoencoderOobleck + Timesteps class
Two more deferred review threads from dg845 addressed:
* Move tiled encode/decode onto AutoencoderOobleck
(#13095 (comment)).
AutoencoderOobleck now carries `use_tiling` + `tile_sample_min_length` /
`tile_sample_overlap` / `tile_latent_min_length` / `tile_latent_overlap`
attributes and private `_tiled_encode` / `_tiled_decode` methods; the
existing `encode` / `_decode` dispatch to them when tiling is enabled and
the input exceeds the threshold. `AutoencoderMixin.enable_tiling()` is
already inherited.
AceStepPipeline's private `_tiled_encode` / `_tiled_decode` and the
`use_tiled_decode` `__call__` arg are gone; `__init__` now calls
`self.vae.enable_tiling()` so the long-audio memory behaviour is preserved
by default. Users can opt out with `pipe.vae.disable_tiling()`.
Note: the VAE-side tiling concatenates encoder features (h) and samples
the posterior once, instead of the old per-tile `.sample()` calls. This
is the standard diffusers pattern; numerically differs only in the
structure of the noise across tile boundaries.
* Use the Timesteps nn.Module for the sinusoid
(#13095 (comment)).
`AceStepTimestepEmbedding` wraps `Timesteps(in_channels, flip_sin_to_cos=
True, downscale_freq_shift=0)` instead of calling `get_timestep_embedding`
directly — reviewer asked for the Module form.
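Minimal sketch of the Module-form sinusoid (the 256-channel width is illustrative, not the actual config):

```python
import torch
from diffusers.models.embeddings import Timesteps

time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
emb = time_proj(torch.tensor([0.5, 0.25]))   # (cos, sin)-ordered embedding, shape [2, 256]
```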
* Address PR #13095 review: refactor AceStepAttention to Attention + AttnProcessor
Splits the monolithic AceStepAttention into the diffusers standard
Attention + AttnProcessor layout:
- AceStepAttention (torch.nn.Module, AttentionModuleMixin) holds the
to_q/to_k/to_v/to_out projections and norm_q/norm_k RMSNorms.
- AceStepAttnProcessor2_0 runs the attention dispatch through
dispatch_attention_fn so users can pick flash / sage / native backends
via model.set_attention_backend(...) or the attention_backend context
manager.
GQA (Q has 16 heads / K,V have 8) is preserved by passing enable_gqa=True
to dispatch_attention_fn instead of repeat_interleave; fusion is disabled
(_supports_qkv_fusion = False) because Q and K,V have different output
sizes.
The converter is updated to rename the six attention sub-keys
(q_proj -> to_q, k_proj -> to_k, v_proj -> to_v, o_proj -> to_out.0,
q_norm -> norm_q, k_norm -> norm_k) on both the DiT decoder path and the
condition encoder path, since AceStepLyricEncoder / AceStepTimbreEncoder
share the same AceStepAttention class.
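Sketch of the converter-side key renames (applied to both the DiT decoder and condition-encoder attention modules); the helper itself is illustrative:

```python
ATTN_KEY_RENAMES = {
    "q_proj": "to_q",
    "k_proj": "to_k",
    "v_proj": "to_v",
    "o_proj": "to_out.0",
    "q_norm": "norm_q",
    "k_norm": "norm_k",
}

def rename_attention_keys(state_dict: dict) -> dict:
    renamed = {}
    for key, value in state_dict.items():
        for old, new in ATTN_KEY_RENAMES.items():
            key = key.replace(f".{old}.", f".{new}.")
        renamed[key] = value
    return renamed
```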
Addresses review comments r2785433213 and r2785450463.
* Address PR #13095 review: migrate to FlowMatchEulerDiscreteScheduler
Replace the hand-rolled flow-matching Euler loop with
`FlowMatchEulerDiscreteScheduler`. ACE-Step still computes its own shifted /
turbo sigma schedule via `_get_timestep_schedule`, but now passes it to
`scheduler.set_timesteps(sigmas=...)` and delegates the ODE step to
`scheduler.step()`. The scheduler is configured with `num_train_timesteps=1`
and `shift=1.0` so `scheduler.timesteps` stays in `[0, 1]` (the convention the
DiT was trained on) and the scheduler doesn't re-shift already-shifted sigmas.
The scheduler's appended terminal `sigma=0` reproduces the old loop's
final-step "project to x0" case exactly: `prev = x + (0 - t_curr) * v`.
Parity on jieyue (seed=42, bf16 + flash-attn, turbo text2music, 8 steps):
waveform Pearson = 0.999999
spectral Pearson = 1.000000
max |diff| = 2.5e-3 (fp32 step-math vs previous bf16 step-math)
fp32 Euler-loop A/B against the hand-rolled path: max |diff| = 3.6e-7.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
* Address PR #13095 review: move DiT tests + drop stale test kwargs
- Move the DiT transformer tests out of the pipeline test file into a new
tests/models/transformers/test_models_transformer_ace_step.py that follows
the standard BaseModelTesterConfig + ModelTesterMixin scaffold (matches
test_models_transformer_longcat_audio_dit.py).
- Drop `max_position_embeddings` from the remaining AceStepDiTModel and
AceStepConditionEncoder test fixtures — neither constructor accepts that
argument anymore.
- Drop `use_sliding_window` from the same fixtures — also no longer a
constructor argument (the actual `sliding_window` int kwarg is kept).
- Wire `FlowMatchEulerDiscreteScheduler(num_train_timesteps=1, shift=1.0)`
into `get_dummy_components()` now that the pipeline requires it.
Resolves #13095 (comment),
r3115664850, r3115673059, r3115676580, r3115680700.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
* Address PR #13095 review from dg845 (2026-04-23)
Fixes 5 review threads + style:
1. Converter now builds `AceStepPipeline` in memory and calls
`save_pretrained`. Previously the hand-written `model_index.json` was
missing the `scheduler` entry — fresh converter output couldn't be loaded
by `AceStepPipeline.from_pretrained` (r3127767785). This also makes the
converter robust to future `__init__` signature changes.
2. `latent_length` uses `math.ceil(...)` instead of `int(...)` so non-integer
products (e.g. `latents_per_second=2.0, audio_duration=0.4 → 0.8`) round up
to `1` instead of truncating to `0` and crashing shape checks (r3127790939).
3. Add `_callback_tensor_inputs = ["latents"]` on `AceStepPipeline` so the
standard diffusers callback tests pick up the right tensor (r3127795954).
4. `AceStepConditionEncoder.silence_latent` no longer hard-codes the channel
dim to 64. The placeholder buffer now uses the `timbre_hidden_dim`
constructor argument, so smaller test configs with `timbre_hidden_dim != 64`
load without shape errors (r3127812932).
5. Revert `self.vae.enable_tiling()` from `AceStepPipeline.__init__`. Users can
call `pipe.vae.enable_tiling()` themselves for long-form generation; that
matches the opt-in convention used by the rest of diffusers (r3127777296).
6. `ruff check --fix` + `ruff format` over all ACE-Step sources (the style fix
dg845 asked for via `@bot /style`).
Also: converter now accepts sharded `model.safetensors.index.json` layouts
alongside the single-file `model.safetensors`, so the 5B XL turbo variant
converts without a pre-processing step.
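Sketch of loading either checkpoint layout (file names per the safetensors convention; the helper itself is illustrative):

```python
import json
import os
from safetensors.torch import load_file

def load_checkpoint_state_dict(checkpoint_dir: str) -> dict:
    index_path = os.path.join(checkpoint_dir, "model.safetensors.index.json")
    if os.path.exists(index_path):
        with open(index_path) as f:
            shard_files = sorted(set(json.load(f)["weight_map"].values()))
        state_dict = {}
        for shard in shard_files:
            state_dict.update(load_file(os.path.join(checkpoint_dir, shard)))
        return state_dict
    return load_file(os.path.join(checkpoint_dir, "model.safetensors"))
```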
Parity on jieyue (seed=42, bf16 + flash-attn, turbo text2music 160s, fresh
converter output loaded via `from_pretrained`):
waveform Pearson = 0.999954
spectral Pearson = 0.999977
max |a-b| bf16 = 4.3e-02 (dominated by the VAE tiling default flip)
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
* Address PR #13095 review from yiyixuxu (2026-04-23)
Code-level (22 threads):
1. Delete 3 dev/parity scripts (`scripts/audio_parity_jieyue.py`,
`scripts/dit_parity_test.py`, `scripts/run_official_generate_music.py`)
that shouldn't have been committed.
2. Rename `AutoencoderOobleck._encode_one` → `_encode` to match the convention
used by other diffusers VAEs.
3. Delete the hard-coded `SHIFT_TIMESTEPS` / `VALID_SHIFTS` table in
`pipeline_ace_step.py`: the per-shift turbo schedules are recovered
exactly by `linspace(1, 0, N+1)[:-1]` plus the flow-match shift formula
that the non-turbo branch already uses, so a single code path covers both.
4. Drop the backwards-compat `AceStepDiTModel` / `AceStepDiTLayer` aliases
and every reference (top-level `__init__`, `models/__init__`,
`transformers/__init__`, dummy objects, tests, docs toctree, model card).
`AceStepTransformer1DModel` is the only exported name now.
5. Remove the unused `attention_mask` / `encoder_attention_mask` args from
`AceStepTransformer1DModel.forward`; the model rebuilds its masks from
the sequence shape and never consumed them.
6. In the DiT forward and both encoders, pass `None` instead of an all-zero
`full_attn_mask` / `encoder_4d_mask` to non-sliding attention layers — SDPA
dispatches to a faster kernel when the mask is None.
7. Inline the shared `_run_encoder_layers` helper directly into
`AceStepLyricEncoder.forward` / `AceStepTimbreEncoder.forward` so layer
calls are visible at the forward boundary (diffusers style).
8. Move `is_turbo` / `sample_rate` / `latents_per_second` from `@property`s
that re-read module configs each call to cached attributes populated in
`__init__` (Flux2-style), with a default-ACE-Step fallback when
`self.vae` is offloaded. Drop the now-unused `SAMPLE_RATE = 48000`
module-level constant and the three property definitions.
9. Warn + coerce `guidance_scale` to 1.0 on turbo (guidance-distilled)
checkpoints, following `pipeline_flux2_klein`. Prevents over-guided
audio when users forward their base/sft CFG settings to a turbo pipe.
10. Remove the `logger.warning(...)` paths that triggered on
`silence_latent` missing/zero — those only fired for author-side
unconverted checkpoints and tests; end users always load converted
weights where the buffer is baked in.
11. Drop the redundant `with torch.no_grad():` wrappers inside
`encode_prompt` — the pipeline's `__call__` runs under `torch.no_grad`
already.
12. Strip "reviewer comment on PR #13095" attribution comments from three
docstrings (here and everywhere).
Parity on jieyue (seed=42, bf16 + flash-attn, XL turbo 160s text2music):
waveform Pearson = 0.9747
spectral Pearson = 0.9895
The shift comes from full-attention layers switching `attn_mask=0_tensor` →
`attn_mask=None`, which dispatches to a different SDPA kernel on bf16. The
two outputs are algebraically equivalent for fp32 eager; on bf16+FA the
delta is dominated by kernel-level ULPs, well within the sampler-noise
band (ear-check on the 160s example confirms no audible regression).
Still open — AudioTokenizer/Detokenizer (deferred) + APG guider follow-up
(dims differ from `diffusers.guiders.adaptive_projected_guidance`, not a
drop-in; worth a separate PR).
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
* Address ACE-Step audio token and APG review
* Fix ACE-Step docs CI
* Address ACE-Step pipeline cleanup review
* Fix ACE-Step flash attention sliding windows
* Add ACE-Step callback properties
* Address ACE-Step final review comments
---------
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>
20 files changed: 4156 additions, 11 deletions
File tree
- docs/source/en
  - api
    - models
    - pipelines
- scripts
- src/diffusers
  - guiders
  - models
    - autoencoders
    - transformers
  - pipelines
    - ace_step
  - utils
- tests
  - models/transformers
  - pipelines/ace_step