2525from diffusers .models .attention_dispatch import AttentionBackendName , _AttentionBackendRegistry
2626
2727from ...testing_utils import (
28+ is_attention ,
2829 is_context_parallel ,
30+ is_kernels_available ,
2931 require_torch_multi_accelerator ,
3032 torch_device ,
3133)
34+ from .utils import _maybe_cast_to_bf16
3235
3336
3437# Device configuration mapping
@@ -47,7 +50,9 @@ def _find_free_port():
4750 return port
4851
4952
50- def _context_parallel_worker (rank , world_size , master_port , model_class , init_dict , cp_dict , inputs_dict , return_dict ):
53+ def _context_parallel_worker (
54+ rank , world_size , master_port , model_class , init_dict , cp_dict , inputs_dict , return_dict , attention_backend = None
55+ ):
5156 """Worker function for context parallel testing."""
5257 try :
5358 # Set up distributed environment
@@ -73,9 +78,16 @@ def _context_parallel_worker(rank, world_size, master_port, model_class, init_di
7378 model .to (device )
7479 model .eval ()
7580
81+ # Cast as needed.
82+ model , inputs_dict = _maybe_cast_to_bf16 (attention_backend , model , inputs_dict )
83+
7684 # Move inputs to device
7785 inputs_on_device = {k : v .to (device ) if isinstance (v , torch .Tensor ) else v for k , v in inputs_dict .items ()}
7886
87+ # Enable attention backend
88+ if attention_backend :
89+ model .set_attention_backend (attention_backend )
90+
7991 # Enable context parallelism
8092 cp_config = ContextParallelConfig (** cp_dict )
8193 model .enable_parallelism (config = cp_config )
@@ -356,3 +368,77 @@ def test_context_parallel_custom_mesh(self, cp_type, mesh_shape, mesh_dim_names)
356368 assert return_dict .get ("status" ) == "success" , (
357369 f"Custom mesh context parallel inference failed: { return_dict .get ('error' , 'Unknown error' )} "
358370 )
371+
372+
@is_attention
@is_context_parallel
@require_torch_multi_accelerator
class ContextParallelAttentionBackendsTesterMixin:
    """Mixin that exercises context-parallel inference across attention backends.

    Host test classes are expected to provide ``model_class``, ``get_init_dict()``
    and ``get_dummy_inputs()``; each parametrized case spawns ``world_size``
    worker processes that run the model under a ``ContextParallelConfig`` with
    the selected attention backend.
    """

    @pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"])
    @pytest.mark.parametrize(
        "attention_backend",
        [
            "native",
            pytest.param(
                "flash_hub",
                marks=pytest.mark.skipif(not is_kernels_available(), reason="`kernels` is not available."),
            ),
            pytest.param(
                "_flash_3_hub",
                marks=pytest.mark.skipif(not is_kernels_available(), reason="`kernels` is not available."),
            ),
        ],
    )
    @pytest.mark.parametrize("ulysses_anything", [True, False])
    @torch.no_grad()
    def test_context_parallel_attn_backend_inference(self, cp_type, attention_backend, ulysses_anything):
        """Run distributed inference for one (cp_type, backend, ulysses_anything) combo.

        Skips when torch.distributed is unavailable, when the model has no
        ``_cp_plan``, or when the parameter combination is unsupported.
        """
        if not torch.distributed.is_available():
            pytest.skip("torch.distributed is not available.")

        if getattr(self.model_class, "_cp_plan", None) is None:
            pytest.skip("Model does not have a _cp_plan defined for context parallel inference.")

        # NOTE(review): this compares the `attention_backend` string against an
        # enum member; it relies on `AttentionBackendName` being a str-enum so
        # that `"native" == AttentionBackendName.NATIVE` holds — TODO confirm.
        if cp_type == "ring_degree" and attention_backend == AttentionBackendName.NATIVE:
            pytest.skip("Skipping test because ring isn't supported with native attention backend.")

        if ulysses_anything and "ulysses" not in cp_type:
            pytest.skip("Skipping test as ulysses anything needs the ulysses degree set.")

        world_size = 2
        init_dict = self.get_init_dict()
        inputs_dict = self.get_dummy_inputs()

        # Move all tensors to CPU so they can be pickled into the spawned workers.
        inputs_dict = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in inputs_dict.items()}
        cp_dict = {cp_type: world_size}
        if ulysses_anything:
            cp_dict["ulysses_anything"] = ulysses_anything

        # Find a free port for distributed communication.
        master_port = _find_free_port()

        # Use a manager-backed dict for cross-process result reporting. The
        # context manager guarantees the manager's server process is shut down
        # even if `mp.spawn` raises (the original leaked it on failure).
        with mp.Manager() as manager:
            return_dict = manager.dict()

            # Spawn worker processes and wait for them to finish.
            mp.spawn(
                _context_parallel_worker,
                args=(
                    world_size,
                    master_port,
                    self.model_class,
                    init_dict,
                    cp_dict,
                    inputs_dict,
                    return_dict,
                    attention_backend,
                ),
                nprocs=world_size,
                join=True,
            )

            # Copy results out of the proxy before the manager shuts down;
            # the proxy is unusable once the `with` block exits.
            result = dict(return_dict)

        assert result.get("status") == "success", (
            f"Context parallel inference failed: {result.get('error', 'Unknown error')}"
        )
0 commit comments