2626from ...callbacks import MultiPipelineCallbacks , PipelineCallback
2727from ...image_processor import PipelineImageInput
2828from ...loaders import SanaLoraLoaderMixin
29- from ...models import AutoencoderDC , AutoencoderKLWan , SanaVideoTransformer3DModel
29+ from ...models import AutoencoderDC , AutoencoderKLLTX2Video , AutoencoderKLWan , SanaVideoTransformer3DModel
3030from ...schedulers import FlowMatchEulerDiscreteScheduler
3131from ...utils import (
3232 BACKENDS_MAPPING ,
@@ -184,7 +184,7 @@ class SanaImageToVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
184184 The tokenizer used to tokenize the prompt.
185185 text_encoder ([`Gemma2PreTrainedModel`]):
186186 Text encoder model to encode the input prompts.
187- vae ([`AutoencoderKLWan` or `AutoencoderDCAEV `]):
187+ vae ([`AutoencoderKLWan`, `AutoencoderDC`, or `AutoencoderKLLTX2Video`]):
188188 Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
189189 transformer ([`SanaVideoTransformer3DModel`]):
190190 Conditional Transformer to denoise the input latents.
@@ -203,7 +203,7 @@ def __init__(
203203 self ,
204204 tokenizer : GemmaTokenizer | GemmaTokenizerFast ,
205205 text_encoder : Gemma2PreTrainedModel ,
206- vae : AutoencoderDC | AutoencoderKLWan ,
206+ vae : AutoencoderDC | AutoencoderKLLTX2Video | AutoencoderKLWan ,
207207 transformer : SanaVideoTransformer3DModel ,
208208 scheduler : FlowMatchEulerDiscreteScheduler ,
209209 ):
@@ -213,8 +213,19 @@ def __init__(
213213 tokenizer = tokenizer , text_encoder = text_encoder , vae = vae , transformer = transformer , scheduler = scheduler
214214 )
215215
216- self .vae_scale_factor_temporal = self .vae .config .scale_factor_temporal if getattr (self , "vae" , None ) else 4
217- self .vae_scale_factor_spatial = self .vae .config .scale_factor_spatial if getattr (self , "vae" , None ) else 8
216+ if getattr (self , "vae" , None ):
217+ if isinstance (self .vae , AutoencoderKLLTX2Video ):
218+ self .vae_scale_factor_temporal = self .vae .config .temporal_compression_ratio
219+ self .vae_scale_factor_spatial = self .vae .config .spatial_compression_ratio
220+ elif isinstance (self .vae , (AutoencoderDC , AutoencoderKLWan )):
221+ self .vae_scale_factor_temporal = self .vae .config .scale_factor_temporal
222+ self .vae_scale_factor_spatial = self .vae .config .scale_factor_spatial
223+ else :
224+ self .vae_scale_factor_temporal = 4
225+ self .vae_scale_factor_spatial = 8
226+ else :
227+ self .vae_scale_factor_temporal = 4
228+ self .vae_scale_factor_spatial = 8
218229
219230 self .vae_scale_factor = self .vae_scale_factor_spatial
220231
@@ -687,14 +698,18 @@ def prepare_latents(
687698 image_latents = retrieve_latents (self .vae .encode (image ), sample_mode = "argmax" )
688699 image_latents = image_latents .repeat (batch_size , 1 , 1 , 1 , 1 )
689700
690- latents_mean = (
691- torch .tensor (self .vae .config .latents_mean )
692- .view (1 , - 1 , 1 , 1 , 1 )
693- .to (image_latents .device , image_latents .dtype )
694- )
695- latents_std = 1.0 / torch .tensor (self .vae .config .latents_std ).view (1 , - 1 , 1 , 1 , 1 ).to (
696- image_latents .device , image_latents .dtype
697- )
701+ if isinstance (self .vae , AutoencoderKLLTX2Video ):
702+ _latents_mean = self .vae .latents_mean
703+ _latents_std = self .vae .latents_std
704+ elif isinstance (self .vae , AutoencoderKLWan ):
705+ _latents_mean = torch .tensor (self .vae .config .latents_mean )
706+ _latents_std = torch .tensor (self .vae .config .latents_std )
707+ else :
708+ _latents_mean = torch .zeros (image_latents .shape [1 ], device = image_latents .device , dtype = image_latents .dtype )
709+ _latents_std = torch .ones (image_latents .shape [1 ], device = image_latents .device , dtype = image_latents .dtype )
710+
711+ latents_mean = _latents_mean .view (1 , - 1 , 1 , 1 , 1 ).to (image_latents .device , image_latents .dtype )
712+ latents_std = 1.0 / _latents_std .view (1 , - 1 , 1 , 1 , 1 ).to (image_latents .device , image_latents .dtype )
698713 image_latents = (image_latents - latents_mean ) * latents_std
699714
700715 latents [:, :, 0 :1 ] = image_latents .to (dtype )
@@ -1034,14 +1049,21 @@ def __call__(
10341049 if is_torch_version (">=" , "2.5.0" )
10351050 else torch_accelerator_module .OutOfMemoryError
10361051 )
1037- latents_mean = (
1038- torch .tensor (self .vae .config .latents_mean )
1039- .view (1 , self .vae .config .z_dim , 1 , 1 , 1 )
1040- .to (latents .device , latents .dtype )
1041- )
1042- latents_std = 1.0 / torch .tensor (self .vae .config .latents_std ).view (1 , self .vae .config .z_dim , 1 , 1 , 1 ).to (
1043- latents .device , latents .dtype
1044- )
1052+ if isinstance (self .vae , AutoencoderKLLTX2Video ):
1053+ latents_mean = self .vae .latents_mean
1054+ latents_std = self .vae .latents_std
1055+ z_dim = self .vae .config .latent_channels
1056+ elif isinstance (self .vae , AutoencoderKLWan ):
1057+ latents_mean = torch .tensor (self .vae .config .latents_mean )
1058+ latents_std = torch .tensor (self .vae .config .latents_std )
1059+ z_dim = self .vae .config .z_dim
1060+ else :
1061+ latents_mean = torch .zeros (latents .shape [1 ], device = latents .device , dtype = latents .dtype )
1062+ latents_std = torch .ones (latents .shape [1 ], device = latents .device , dtype = latents .dtype )
1063+ z_dim = latents .shape [1 ]
1064+
1065+ latents_mean = latents_mean .view (1 , z_dim , 1 , 1 , 1 ).to (latents .device , latents .dtype )
1066+ latents_std = 1.0 / latents_std .view (1 , z_dim , 1 , 1 , 1 ).to (latents .device , latents .dtype )
10451067 latents = latents / latents_std + latents_mean
10461068 try :
10471069 video = self .vae .decode (latents , return_dict = False )[0 ]
0 commit comments