Skip to content

Commit 2480388

Browse files
committed
tests: eager tests
1 parent a76953c commit 2480388

4 files changed

Lines changed: 33 additions & 64 deletions

File tree

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 0 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -2249,64 +2249,6 @@ def _is_pipeline_device_mapped(self):
22492249

22502250
return not is_device_type_map and isinstance(device_map, dict) and len(device_map) > 1
22512251

2252-
def enable_neuron_compile(
2253-
self,
2254-
model_names: Optional[List[str]] = None,
2255-
cache_dir: Optional[str] = None,
2256-
fullgraph: bool = True,
2257-
) -> None:
2258-
"""
2259-
Compiles the pipeline's nn.Module components with ``torch.compile(backend="neuron")``,
2260-
enabling whole-graph NEFF compilation for AWS Trainium/Inferentia.
2261-
2262-
The first forward call per component triggers neuronx-cc compilation (slow).
2263-
Use ``neuron_warmup()`` to trigger this explicitly before timed inference.
2264-
2265-
Args:
2266-
model_names (`List[str]`, *optional*):
2267-
Component names to compile. Defaults to all nn.Module components.
2268-
cache_dir (`str`, *optional*):
2269-
Path to persist compiled NEFFs across runs via ``TORCH_NEURONX_NEFF_CACHE_DIR``.
2270-
Skips recompilation on subsequent runs.
2271-
fullgraph (`bool`, defaults to `True`):
2272-
Disallow graph breaks (required for full-graph fusion).
2273-
"""
2274-
requires_backends(self, "torch_neuronx")
2275-
import torch_neuronx # noqa: F401 — registers neuron backend
2276-
from torch_neuronx.neuron_dynamo_backend import set_model_name
2277-
2278-
if cache_dir is not None:
2279-
os.environ["TORCH_NEURONX_NEFF_CACHE_DIR"] = cache_dir
2280-
2281-
if model_names is None:
2282-
model_names = [
2283-
name for name, comp in self.components.items() if isinstance(comp, torch.nn.Module)
2284-
]
2285-
2286-
for name in model_names:
2287-
component = getattr(self, name, None)
2288-
if isinstance(component, torch.nn.Module) and not is_compiled_module(component):
2289-
logger.info(f"Compiling {name} with backend='neuron'")
2290-
set_model_name(name)
2291-
setattr(self, name, torch.compile(component, backend="neuron", fullgraph=fullgraph))
2292-
2293-
def neuron_warmup(self, *args, **kwargs) -> None:
2294-
"""
2295-
Runs a single dummy forward pass through the pipeline to trigger neuronx-cc
2296-
compilation for all components (static-shape NEFF compilation).
2297-
2298-
This is equivalent to calling ``__call__`` with the same shapes but discards
2299-
the output. After warmup, subsequent calls reuse the compiled NEFFs and run fast.
2300-
2301-
Pass the same arguments you would use for real inference (height, width,
2302-
num_inference_steps, batch_size, etc.) so that the compiled shapes match.
2303-
"""
2304-
logger.info("Running Neuron warmup forward pass to trigger NEFF compilation...")
2305-
with torch.no_grad():
2306-
self(*args, **kwargs)
2307-
logger.info("Neuron warmup complete.")
2308-
2309-
23102252
class StableDiffusionMixin:
23112253
r"""
23122254
Helper for DiffusionPipeline with vae and unet.(mainly for LDM such as stable diffusion)

src/diffusers/utils/testing_utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
is_peft_available,
4747
is_timm_available,
4848
is_torch_available,
49+
is_torch_neuronx_available,
4950
is_torch_version,
5051
is_torchao_available,
5152
is_torchsde_available,
@@ -113,6 +114,8 @@
113114
torch_device = "cuda"
114115
elif torch.xpu.is_available():
115116
torch_device = "xpu"
117+
elif is_torch_neuronx_available() and hasattr(torch, "neuron") and torch.neuron.is_available():
118+
torch_device = torch.neuron.current_device()
116119
else:
117120
torch_device = "cpu"
118121
is_torch_higher_equal_than_1_12 = version.parse(

tests/pipelines/pixart_alpha/test_pixart.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@
2828
PixArtTransformer2DModel,
2929
)
3030

31+
from diffusers.utils.import_utils import is_torch_neuronx_available
32+
3133
from ...testing_utils import (
3234
backend_empty_cache,
3335
enable_full_determinism,
@@ -291,7 +293,9 @@ def test_pixart_1024(self):
291293
expected_slice = np.array([0.0742, 0.0835, 0.2114, 0.0295, 0.0784, 0.2361, 0.1738, 0.2251, 0.3589])
292294

293295
max_diff = numpy_cosine_similarity_distance(image_slice.flatten(), expected_slice)
294-
self.assertLessEqual(max_diff, 1e-4)
296+
# Neuron uses bfloat16 internally which has lower precision than float16 on CUDA
297+
atol = 1e-2 if is_torch_neuronx_available() else 1e-4
298+
self.assertLessEqual(max_diff, atol)
295299

296300
def test_pixart_512(self):
297301
generator = torch.Generator("cpu").manual_seed(0)
@@ -307,7 +311,9 @@ def test_pixart_512(self):
307311
expected_slice = np.array([0.3477, 0.3882, 0.4541, 0.3413, 0.3821, 0.4463, 0.4001, 0.4409, 0.4958])
308312

309313
max_diff = numpy_cosine_similarity_distance(image_slice.flatten(), expected_slice)
310-
self.assertLessEqual(max_diff, 1e-4)
314+
# Neuron uses bfloat16 internally which has lower precision than float16 on CUDA
315+
atol = 1e-2 if is_torch_neuronx_available() else 1e-4
316+
self.assertLessEqual(max_diff, atol)
311317

312318
def test_pixart_1024_without_resolution_binning(self):
313319
generator = torch.manual_seed(0)

tests/testing_utils.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
is_peft_available,
4646
is_timm_available,
4747
is_torch_available,
48+
is_torch_neuronx_available,
4849
is_torch_version,
4950
is_torchao_available,
5051
is_torchsde_available,
@@ -109,6 +110,8 @@
109110
torch_device = "cuda"
110111
elif torch.xpu.is_available():
111112
torch_device = "xpu"
113+
elif is_torch_neuronx_available() and hasattr(torch, "neuron") and torch.neuron.is_available():
114+
torch_device = torch.neuron.current_device()
112115
else:
113116
torch_device = "cpu"
114117
is_torch_higher_equal_than_1_12 = version.parse(
@@ -1427,6 +1430,15 @@ def _is_torch_fp64_available(device):
14271430
# Behaviour flags
14281431
BACKEND_SUPPORTS_TRAINING = {"cuda": True, "xpu": True, "cpu": True, "mps": False, "default": True}
14291432

1433+
# Neuron device key: torch.neuron.current_device() returns an int (e.g. 0).
1434+
# We capture it once at import time if torch_neuronx is available so we can add it
1435+
# to all dispatch tables using the same key that torch_device is set to.
1436+
_neuron_device = (
1437+
torch.neuron.current_device()
1438+
if (is_torch_neuronx_available() and hasattr(torch, "neuron") and torch.neuron.is_available())
1439+
else None
1440+
)
1441+
14301442
# Function definitions
14311443
BACKEND_EMPTY_CACHE = {
14321444
"cuda": torch.cuda.empty_cache,
@@ -1478,13 +1490,19 @@ def _is_torch_fp64_available(device):
14781490
"default": None,
14791491
}
14801492

1493+
if _neuron_device is not None:
1494+
BACKEND_EMPTY_CACHE[_neuron_device] = None
1495+
BACKEND_DEVICE_COUNT[_neuron_device] = torch.neuron.device_count
1496+
BACKEND_MANUAL_SEED[_neuron_device] = torch.manual_seed
1497+
BACKEND_RESET_PEAK_MEMORY_STATS[_neuron_device] = None
1498+
BACKEND_RESET_MAX_MEMORY_ALLOCATED[_neuron_device] = None
1499+
BACKEND_MAX_MEMORY_ALLOCATED[_neuron_device] = 0
1500+
BACKEND_SYNCHRONIZE[_neuron_device] = torch.neuron.synchronize
1501+
14811502

14821503
# This dispatches a defined function according to the accelerator from the function definitions.
14831504
def _device_agnostic_dispatch(device: str, dispatch_table: dict[str, Callable], *args, **kwargs):
1484-
if device not in dispatch_table:
1485-
return dispatch_table["default"](*args, **kwargs)
1486-
1487-
fn = dispatch_table[device]
1505+
fn = dispatch_table[device] if device in dispatch_table else dispatch_table["default"]
14881506

14891507
# Some device agnostic functions return values. Need to guard against 'None' instead at
14901508
# user level

0 commit comments

Comments
 (0)