Commit 0fff459

Fix ErnieImagePipeline pre-computed prompt_embeds + num_images_per_prompt shape mismatch (#13532)
Fix ErnieImagePipeline pre-computed prompt_embeds + num_images_per_prompt

When a user passes pre-computed `prompt_embeds` (or `negative_prompt_embeds`) alongside `num_images_per_prompt > 1`, `ErnieImagePipeline.__call__` did not replicate the provided embeddings: the embeds list kept its original length (one per prompt) while the latents were allocated with `total_batch_size = batch_size * num_images_per_prompt`:

```python
text_hiddens = prompt_embeds                      # length = batch_size (NOT replicated)
...
latents = randn_tensor((total_batch_size, ...))   # batch * N in shape
```

In the denoise loop, `text_bth.shape[0]` then mismatches `latent_model_input.shape[0]`, so the transformer call:

```python
pred = self.transformer(
    hidden_states=latent_model_input,  # (batch*N*2, ...) under CFG
    text_bth=text_bth,                 # (batch*2, ...)
    ...
)
```

fails with a shape mismatch inside the attention block. The standard "pre-compute embeds once, generate N variants" usage pattern is broken.

`encode_prompt` already performs this replication internally (`for _ in range(num_images_per_prompt): text_hiddens.append(hidden)` at lines 158-160), so the non-embed path is unaffected; this only impacts callers of the documented `prompt_embeds` / `negative_prompt_embeds` arguments.

Mirror the replication logic in the pre-embed branches so both paths yield a `text_hiddens` list of length `batch_size * num_images_per_prompt`.
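The replication the fix mirrors can be shown in isolation. A minimal sketch, assuming plain Python strings stand in for embedding tensors (the helper name `replicate_embeds` is hypothetical; the list comprehension is the one from the patch):

```python
def replicate_embeds(prompt_embeds, num_images_per_prompt):
    # Mirror encode_prompt's internal loop: each prompt's embedding is
    # repeated num_images_per_prompt times, in prompt-major order
    # (all copies of prompt 0, then all copies of prompt 1, ...).
    return [h for h in prompt_embeds for _ in range(num_images_per_prompt)]

# Two prompts, three variants each -> list of length 2 * 3 = 6.
embeds = ["embed_a", "embed_b"]
replicated = replicate_embeds(embeds, 3)
print(replicated)  # ['embed_a', 'embed_a', 'embed_a', 'embed_b', 'embed_b', 'embed_b']
```

The prompt-major ordering matters: it matches what `encode_prompt` produces, so downstream code that pairs `text_hiddens[i]` with `latents[i]` behaves identically on both paths.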
1 parent 2173c55 commit 0fff459

1 file changed

Lines changed: 2 additions & 2 deletions

File tree

src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py

```diff
@@ -286,14 +286,14 @@ def __call__(

         # [Phase 2] Text encoding
         if prompt_embeds is not None:
-            text_hiddens = prompt_embeds
+            text_hiddens = [h for h in prompt_embeds for _ in range(num_images_per_prompt)]
         else:
             text_hiddens = self.encode_prompt(prompt, device, num_images_per_prompt)

         # CFG with negative prompt
         if self.do_classifier_free_guidance:
             if negative_prompt_embeds is not None:
-                uncond_text_hiddens = negative_prompt_embeds
+                uncond_text_hiddens = [h for h in negative_prompt_embeds for _ in range(num_images_per_prompt)]
             else:
                 uncond_text_hiddens = self.encode_prompt(negative_prompt, device, num_images_per_prompt)
```
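The pre-fix failure mode can be sketched with list lengths standing in for the batch dimension (all variable names here are illustrative, following the commit message's shapes):

```python
batch_size = 2
num_images_per_prompt = 4
total_batch_size = batch_size * num_images_per_prompt  # latents are allocated with this

prompt_embeds = [f"embed_{i}" for i in range(batch_size)]

# Before the fix: the pre-computed embeds passed through unchanged,
# so the text batch (2) never matched the latent batch (8).
text_hiddens_old = prompt_embeds
assert len(text_hiddens_old) != total_batch_size  # 2 vs 8 -> shape error in attention

# After the fix: replicated to match the latents.
text_hiddens_new = [h for h in prompt_embeds for _ in range(num_images_per_prompt)]
assert len(text_hiddens_new) == total_batch_size  # 8 == 8
```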
