|
| 1 | +"""FLUX inference on TPU using PyTorch/XLA SPMD. |
| 2 | +
|
| 3 | +Uses SPMD to shard the transformer across multiple TPU chips, enabling |
| 4 | +inference on devices where the model doesn't fit on a single chip (e.g., v5e). |
| 5 | +The VAE is loaded on CPU at startup, moved to XLA for decode, then moved back. |
| 6 | +""" |
| 7 | + |
| 8 | +from argparse import ArgumentParser |
| 9 | +from pathlib import Path |
| 10 | +from time import perf_counter |
| 11 | + |
| 12 | +import numpy as np |
| 13 | +import structlog |
| 14 | +import torch |
| 15 | +import torch_xla.core.xla_model as xm |
| 16 | +import torch_xla.debug.metrics as met |
| 17 | +import torch_xla.debug.profiler as xp |
| 18 | +import torch_xla.distributed.spmd as xs |
| 19 | +import torch_xla.runtime as xr |
| 20 | +from torch_xla.experimental.custom_kernel import FlashAttention |
| 21 | + |
| 22 | +from diffusers import AutoencoderKL, FluxPipeline |
| 23 | + |
| 24 | + |
| 25 | +cache_path = Path("/tmp/data/compiler_cache_eXp") |
| 26 | +cache_path.mkdir(parents=True, exist_ok=True) |
| 27 | +xr.initialize_cache(str(cache_path), readonly=False) |
| 28 | +xr.use_spmd() |
| 29 | + |
| 30 | +logger = structlog.get_logger() |
| 31 | +metrics_filepath = "/tmp/metrics_report.txt" |
| 32 | +VAE_SCALE_FACTOR = 8 |
| 33 | + |
| 34 | + |
| 35 | +def _vae_decode(latents, vae, height, width, device): |
| 36 | + """Move VAE to XLA, decode latents, move VAE back to CPU.""" |
| 37 | + vae.to(device) |
| 38 | + latents = FluxPipeline._unpack_latents(latents, height, width, VAE_SCALE_FACTOR) |
| 39 | + latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor |
| 40 | + with torch.no_grad(): |
| 41 | + image = vae.decode(latents, return_dict=False)[0] |
| 42 | + vae.to("cpu") |
| 43 | + return image |
| 44 | + |
| 45 | + |
| 46 | +def main(args): |
| 47 | + # --- SPMD mesh: 4-way model parallel to fit transformer + VAE on v5e chips --- |
| 48 | + num_devices = xr.global_runtime_device_count() |
| 49 | + if num_devices >= 4: |
| 50 | + mesh = xs.Mesh(np.arange(num_devices), (num_devices // 4, 4), ("data", "model")) |
| 51 | + else: |
| 52 | + NotImplementedError |
| 53 | + xs.set_global_mesh(mesh) |
| 54 | + logger.info(f"SPMD mesh: {mesh.mesh_shape}, axes: {mesh.axis_names}, devices: {num_devices}") |
| 55 | + |
| 56 | + # --- Profiler --- |
| 57 | + profile_path = Path("/tmp/data/profiler_out_eXp") |
| 58 | + profile_path.mkdir(parents=True, exist_ok=True) |
| 59 | + profiler_port = 9012 |
| 60 | + profile_duration = args.profile_duration |
| 61 | + if args.profile: |
| 62 | + logger.info(f"starting profiler on port {profiler_port}") |
| 63 | + _ = xp.start_server(profiler_port) |
| 64 | + |
| 65 | + device = xm.xla_device() |
| 66 | + |
| 67 | + # --- Checkpoint --- |
| 68 | + if args.schnell: |
| 69 | + ckpt_id = "black-forest-labs/FLUX.1-schnell" |
| 70 | + else: |
| 71 | + ckpt_id = "black-forest-labs/FLUX.1-dev" |
| 72 | + |
| 73 | + # --- Text encoding (CPU) --- |
| 74 | + prompt = "photograph of an electronics chip in the shape of a race car with trillium written on its side" |
| 75 | + logger.info("encoding prompt on CPU...") |
| 76 | + text_pipe = FluxPipeline.from_pretrained(ckpt_id, transformer=None, vae=None, torch_dtype=torch.bfloat16).to("cpu") |
| 77 | + with torch.no_grad(): |
| 78 | + prompt_embeds, pooled_prompt_embeds, _ = text_pipe.encode_prompt( |
| 79 | + prompt=prompt, prompt_2=None, max_sequence_length=512 |
| 80 | + ) |
| 81 | + image_processor = text_pipe.image_processor |
| 82 | + del text_pipe |
| 83 | + |
| 84 | + # --- Load VAE on CPU (moved to XLA only for decode) --- |
| 85 | + logger.info("loading VAE on CPU...") |
| 86 | + vae = AutoencoderKL.from_pretrained(ckpt_id, subfolder="vae", torch_dtype=torch.bfloat16) |
| 87 | + |
| 88 | + # --- Load transformer and shard --- |
| 89 | + logger.info(f"loading flux transformer from {ckpt_id}") |
| 90 | + flux_pipe = FluxPipeline.from_pretrained( |
| 91 | + ckpt_id, |
| 92 | + text_encoder=None, |
| 93 | + tokenizer=None, |
| 94 | + text_encoder_2=None, |
| 95 | + tokenizer_2=None, |
| 96 | + vae=None, |
| 97 | + torch_dtype=torch.bfloat16, |
| 98 | + ).to(device) |
| 99 | + |
| 100 | + for name, param in flux_pipe.transformer.named_parameters(): |
| 101 | + if param.dim() >= 2: |
| 102 | + spec = [None] * param.dim() |
| 103 | + largest_dim = max(range(param.dim()), key=lambda d: param.shape[d]) |
| 104 | + spec[largest_dim] = "model" |
| 105 | + xs.mark_sharding(param, mesh, tuple(spec)) |
| 106 | + |
| 107 | + flux_pipe.transformer.enable_xla_flash_attention(partition_spec=("data", None, None, None), is_flux=True) |
| 108 | + FlashAttention.DEFAULT_BLOCK_SIZES = { |
| 109 | + "block_q": 1536, |
| 110 | + "block_k_major": 1536, |
| 111 | + "block_k": 1536, |
| 112 | + "block_b": 1536, |
| 113 | + "block_q_major_dkv": 1536, |
| 114 | + "block_k_major_dkv": 1536, |
| 115 | + "block_q_dkv": 1536, |
| 116 | + "block_k_dkv": 1536, |
| 117 | + "block_q_dq": 1536, |
| 118 | + "block_k_dq": 1536, |
| 119 | + "block_k_major_dq": 1536, |
| 120 | + } |
| 121 | + |
| 122 | + width = args.width |
| 123 | + height = args.height |
| 124 | + guidance = args.guidance |
| 125 | + n_steps = 4 if args.schnell else 28 |
| 126 | + |
| 127 | + prompt_embeds = prompt_embeds.to(device) |
| 128 | + pooled_prompt_embeds = pooled_prompt_embeds.to(device) |
| 129 | + xs.mark_sharding(prompt_embeds, mesh, ("data", None, None)) |
| 130 | + xs.mark_sharding(pooled_prompt_embeds, mesh, ("data", None)) |
| 131 | + |
| 132 | + # --- Compilation run --- |
| 133 | + logger.info("starting compilation run...") |
| 134 | + ts = perf_counter() |
| 135 | + latents = flux_pipe( |
| 136 | + prompt_embeds=prompt_embeds, |
| 137 | + pooled_prompt_embeds=pooled_prompt_embeds, |
| 138 | + num_inference_steps=28, |
| 139 | + guidance_scale=guidance, |
| 140 | + height=height, |
| 141 | + width=width, |
| 142 | + output_type="latent", |
| 143 | + ).images |
| 144 | + image = _vae_decode(latents, vae, height, width, device) |
| 145 | + image = image_processor.postprocess(image)[0] |
| 146 | + logger.info(f"compilation took {perf_counter() - ts} sec.") |
| 147 | + image.save("/tmp/compile_out.png") |
| 148 | + |
| 149 | + # --- Inference loop --- |
| 150 | + seed = 4096 if args.seed is None else args.seed |
| 151 | + xm.set_rng_state(seed=seed, device=device) |
| 152 | + times = [] |
| 153 | + logger.info("starting inference run...") |
| 154 | + for _ in range(args.itters): |
| 155 | + ts = perf_counter() |
| 156 | + |
| 157 | + if args.profile: |
| 158 | + xp.trace_detached(f"localhost:{profiler_port}", str(profile_path), duration_ms=profile_duration) |
| 159 | + latents = flux_pipe( |
| 160 | + prompt_embeds=prompt_embeds, |
| 161 | + pooled_prompt_embeds=pooled_prompt_embeds, |
| 162 | + num_inference_steps=n_steps, |
| 163 | + guidance_scale=guidance, |
| 164 | + height=height, |
| 165 | + width=width, |
| 166 | + output_type="latent", |
| 167 | + ).images |
| 168 | + image = _vae_decode(latents, vae, height, width, device) |
| 169 | + image = image_processor.postprocess(image)[0] |
| 170 | + inference_time = perf_counter() - ts |
| 171 | + logger.info(f"inference time: {inference_time}") |
| 172 | + times.append(inference_time) |
| 173 | + |
| 174 | + logger.info(f"avg. inference over {args.itters} iterations took {sum(times) / len(times)} sec.") |
| 175 | + image.save("/tmp/inference_out.png") |
| 176 | + metrics_report = met.metrics_report() |
| 177 | + with open(metrics_filepath, "w+") as fout: |
| 178 | + fout.write(metrics_report) |
| 179 | + logger.info(f"saved metric information as {metrics_filepath}") |
| 180 | + |
| 181 | + |
| 182 | +if __name__ == "__main__": |
| 183 | + parser = ArgumentParser() |
| 184 | + parser.add_argument("--schnell", action="store_true", help="run flux schnell instead of dev") |
| 185 | + parser.add_argument("--width", type=int, default=1024, help="width of the image to generate") |
| 186 | + parser.add_argument("--height", type=int, default=1024, help="height of the image to generate") |
| 187 | + parser.add_argument("--guidance", type=float, default=3.5, help="guidance strength for dev") |
| 188 | + parser.add_argument("--seed", type=int, default=None, help="seed for inference") |
| 189 | + parser.add_argument("--profile", action="store_true", help="enable profiling") |
| 190 | + parser.add_argument("--profile-duration", type=int, default=10000, help="duration for profiling in msec.") |
| 191 | + parser.add_argument("--itters", type=int, default=15, help="items to run inference and get avg time in sec.") |
| 192 | + args = parser.parse_args() |
| 193 | + main(args) |
0 commit comments