Skip to content

Commit a76953c

Browse files
committed
feat: neuron-specific changes in the pipeline
1 parent 0c51734 commit a76953c

5 files changed

Lines changed: 33 additions & 7 deletions

File tree

src/diffusers/models/unets/unet_2d_condition.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -855,10 +855,11 @@ def get_time_embed(self, sample: torch.Tensor, timestep: torch.Tensor | float |
855855
# This would be a good case for the `match` statement (Python 3.10+)
856856
is_mps = sample.device.type == "mps"
857857
is_npu = sample.device.type == "npu"
858+
is_neuron = sample.device.type == "neuron"
858859
if isinstance(timestep, float):
859-
dtype = torch.float32 if (is_mps or is_npu) else torch.float64
860+
dtype = torch.float32 if (is_mps or is_npu or is_neuron) else torch.float64
860861
else:
861-
dtype = torch.int32 if (is_mps or is_npu) else torch.int64
862+
dtype = torch.int32 if (is_mps or is_npu or is_neuron) else torch.int64
862863
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
863864
elif len(timesteps.shape) == 0:
864865
timesteps = timesteps[None].to(sample.device)

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2273,6 +2273,7 @@ def enable_neuron_compile(
22732273
"""
22742274
requires_backends(self, "torch_neuronx")
22752275
import torch_neuronx # noqa: F401 — registers neuron backend
2276+
from torch_neuronx.neuron_dynamo_backend import set_model_name
22762277

22772278
if cache_dir is not None:
22782279
os.environ["TORCH_NEURONX_NEFF_CACHE_DIR"] = cache_dir
@@ -2286,6 +2287,7 @@ def enable_neuron_compile(
22862287
component = getattr(self, name, None)
22872288
if isinstance(component, torch.nn.Module) and not is_compiled_module(component):
22882289
logger.info(f"Compiling {name} with backend='neuron'")
2290+
set_model_name(name)
22892291
setattr(self, name, torch.compile(component, backend="neuron", fullgraph=fullgraph))
22902292

22912293
def neuron_warmup(self, *args, **kwargs) -> None:

src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py

Lines changed: 22 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1092,7 +1092,11 @@ def __call__(
10921092
)
10931093

10941094
# 4. Prepare timesteps
1095-
if XLA_AVAILABLE:
1095+
# Keep timesteps on CPU for XLA (TPU) and Neuron: both use lazy/XLA execution where
1096+
# dynamic-shape ops like .nonzero() and .item() inside scheduler.index_for_timestep()
1097+
# are incompatible with static-graph compilation.
1098+
is_neuron_device = hasattr(device, "type") and device.type == "neuron"
1099+
if XLA_AVAILABLE or is_neuron_device:
10961100
timestep_device = "cpu"
10971101
else:
10981102
timestep_device = device
@@ -1195,15 +1199,23 @@ def __call__(
11951199
# expand the latents if we are doing classifier free guidance
11961200
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
11971201

1198-
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
1202+
# For Neuron: scale_model_input on CPU to avoid XLA ops outside the compiled UNet region.
1203+
# index_for_timestep() uses .nonzero()/.item() which are incompatible with static graphs.
1204+
if is_neuron_device:
1205+
latent_model_input = self.scheduler.scale_model_input(latent_model_input.to("cpu"), t).to(device)
1206+
else:
1207+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
11991208

12001209
# predict the noise residual
12011210
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
12021211
if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
12031212
added_cond_kwargs["image_embeds"] = image_embeds
1213+
# For Neuron: pre-cast timestep to float32 on device. Neuron XLA does not support
1214+
# int64 ops; the compiled UNet graph requires a float32 timestep input on-device.
1215+
t_unet = t.to(torch.float32).to(device) if is_neuron_device else t
12041216
noise_pred = self.unet(
12051217
latent_model_input,
1206-
t,
1218+
t_unet,
12071219
encoder_hidden_states=prompt_embeds,
12081220
timestep_cond=timestep_cond,
12091221
cross_attention_kwargs=self.cross_attention_kwargs,
@@ -1222,7 +1234,13 @@ def __call__(
12221234

12231235
# compute the previous noisy sample x_t -> x_t-1
12241236
latents_dtype = latents.dtype
1225-
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
1237+
# For Neuron: scheduler.step on CPU to keep scheduler arithmetic off the XLA device.
1238+
if is_neuron_device:
1239+
latents = self.scheduler.step(
1240+
noise_pred.to("cpu"), t, latents.to("cpu"), **extra_step_kwargs, return_dict=False
1241+
)[0].to(device)
1242+
else:
1243+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
12261244
if latents.dtype != latents_dtype:
12271245
if torch.backends.mps.is_available():
12281246
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272

src/diffusers/utils/import_utils.py

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -584,6 +584,10 @@ def is_av_available():
584584
"""
585585

586586

587+
TORCH_NEURONX_IMPORT_ERROR = """
588+
{0} requires the torch_neuronx library (AWS Neuron SDK) but it was not found in your environment. Please install it following the AWS Neuron documentation: https://awsdocs-neuron.readthedocs-hosted.com/en/latest/
589+
"""
590+
587591
BACKENDS_MAPPING = OrderedDict(
588592
[
589593
("bs4", (is_bs4_available, BS4_IMPORT_ERROR)),
@@ -614,6 +618,7 @@ def is_av_available():
614618
("pytorch_retinaface", (is_pytorch_retinaface_available, PYTORCH_RETINAFACE_IMPORT_ERROR)),
615619
("better_profanity", (is_better_profanity_available, BETTER_PROFANITY_IMPORT_ERROR)),
616620
("nltk", (is_nltk_available, NLTK_IMPORT_ERROR)),
621+
("torch_neuronx", (is_torch_neuronx_available, TORCH_NEURONX_IMPORT_ERROR)),
617622
]
618623
)
619624

src/diffusers/utils/torch_utils.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -93,7 +93,7 @@
9393
"xpu": getattr(torch.xpu, "synchronize", None),
9494
"cpu": None,
9595
"mps": None,
96-
"neuron": None,
96+
"neuron": lambda: getattr(getattr(torch, "neuron", None), "synchronize", lambda: None)(),
9797
"default": None,
9898
}
9999
logger = logging.get_logger(__name__) # pylint: disable=invalid-name

0 commit comments

Comments (0)