@@ -5829,6 +5829,217 @@ def unfuse_lora(self, components: list[str] = ["transformer"], **kwargs):
58295829 super ().unfuse_lora (components = components , ** kwargs )
58305830
58315831
class ErnieImageLoraLoaderMixin(LoraBaseMixin):
    r"""
    Load LoRA layers into [`ErnieImageTransformer2DModel`]. Specific to [`ErnieImagePipeline`].
    """

    # Only the transformer can receive LoRA layers for this pipeline.
    _lora_loadable_modules = ["transformer"]
    transformer_name = TRANSFORMER_NAME

    @classmethod
    @validate_hf_hub_args
    def lora_state_dict(
        cls,
        pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor],
        **kwargs,
    ):
        r"""
        Fetch a LoRA checkpoint and normalize its keys to the `transformer.*` prefix expected by
        diffusers. DoRA checkpoints are not supported; their `dora_scale` entries are filtered out
        with a warning.

        See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details.
        """
        # Load the main state dict first which has the LoRA layers for either of
        # transformer and text encoder or both.
        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", None)
        token = kwargs.pop("token", None)
        revision = kwargs.pop("revision", None)
        subfolder = kwargs.pop("subfolder", None)
        weight_name = kwargs.pop("weight_name", None)
        use_safetensors = kwargs.pop("use_safetensors", None)
        return_lora_metadata = kwargs.pop("return_lora_metadata", False)

        # When the caller did not express a preference, try safetensors first but
        # allow falling back to a pickled checkpoint.
        allow_pickle = False
        if use_safetensors is None:
            use_safetensors = True
            allow_pickle = True

        user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}

        state_dict, metadata = _fetch_state_dict(
            pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
            weight_name=weight_name,
            use_safetensors=use_safetensors,
            local_files_only=local_files_only,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            token=token,
            revision=revision,
            subfolder=subfolder,
            user_agent=user_agent,
            allow_pickle=allow_pickle,
        )

        is_dora_scale_present = any("dora_scale" in k for k in state_dict)
        if is_dora_scale_present:
            warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
            logger.warning(warn_msg)
            state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}

        # PEFT format -> normalize to diffusion_model.* prefix. The rewrite is anchored to the
        # start of the key so a parameter whose name merely contains the substring elsewhere is
        # left untouched.
        is_peft_format = any(k.startswith("base_model.model.") for k in state_dict)
        if is_peft_format:
            state_dict = {
                ("diffusion_model." + k.removeprefix("base_model.model.") if k.startswith("base_model.model.") else k): v
                for k, v in state_dict.items()
            }

        # AI-Toolkit / diffusion_model.* prefix -> swap to transformer.*
        # The Ernie LoRA naming under diffusion_model.* already matches diffusers module
        # paths (layers.X.self_attention.to_q etc.), so only the prefix needs to change.
        is_diffusion_model_prefix = any(k.startswith("diffusion_model.") for k in state_dict)
        if is_diffusion_model_prefix:
            state_dict = {
                ("transformer." + k.removeprefix("diffusion_model.") if k.startswith("diffusion_model.") else k): v
                for k, v in state_dict.items()
            }

        out = (state_dict, metadata) if return_lora_metadata else state_dict
        return out

    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
    def load_lora_weights(
        self,
        pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor],
        adapter_name: str | None = None,
        hotswap: bool = False,
        **kwargs,
    ):
        """
        See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for more details.
        """
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for this method.")

        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
        if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
            raise ValueError(
                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
            )

        # if a dict is passed, copy it instead of modifying it inplace
        if isinstance(pretrained_model_name_or_path_or_dict, dict):
            pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()

        # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
        kwargs["return_lora_metadata"] = True
        state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)

        is_correct_format = all("lora" in key for key in state_dict.keys())
        if not is_correct_format:
            raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")

        self.load_lora_into_transformer(
            state_dict,
            transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
            adapter_name=adapter_name,
            metadata=metadata,
            _pipeline=self,
            low_cpu_mem_usage=low_cpu_mem_usage,
            hotswap=hotswap,
        )

    @classmethod
    # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->ErnieImageTransformer2DModel
    def load_lora_into_transformer(
        cls,
        state_dict,
        transformer,
        adapter_name=None,
        _pipeline=None,
        low_cpu_mem_usage=False,
        hotswap: bool = False,
        metadata=None,
    ):
        """
        See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details.
        """
        if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
            raise ValueError(
                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
            )

        # Load the layers corresponding to transformer.
        logger.info(f"Loading {cls.transformer_name}.")
        transformer.load_lora_adapter(
            state_dict,
            network_alphas=None,
            adapter_name=adapter_name,
            metadata=metadata,
            _pipeline=_pipeline,
            low_cpu_mem_usage=low_cpu_mem_usage,
            hotswap=hotswap,
        )

    @classmethod
    def save_lora_weights(
        cls,
        save_directory: str | os.PathLike,
        transformer_lora_layers: dict[str, torch.nn.Module | torch.Tensor] = None,
        is_main_process: bool = True,
        weight_name: str = None,
        save_function: Callable = None,
        safe_serialization: bool = True,
        transformer_lora_adapter_metadata: dict | None = None,
    ):
        r"""
        See [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for more information.
        """
        lora_layers = {}
        lora_metadata = {}

        if transformer_lora_layers:
            lora_layers[cls.transformer_name] = transformer_lora_layers
            lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata

        # NOTE(review): this mixin only supports transformer LoRA layers, so the error
        # must not tell users to pass `text_encoder_lora_layers` (not a parameter here).
        if not lora_layers:
            raise ValueError("You must pass `transformer_lora_layers`.")

        cls._save_lora_weights(
            save_directory=save_directory,
            lora_layers=lora_layers,
            lora_metadata=lora_metadata,
            is_main_process=is_main_process,
            weight_name=weight_name,
            save_function=save_function,
            safe_serialization=safe_serialization,
        )

    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
    def fuse_lora(
        self,
        components: list[str] = ["transformer"],
        lora_scale: float = 1.0,
        safe_fusing: bool = False,
        adapter_names: list[str] | None = None,
        **kwargs,
    ):
        r"""
        See [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] for more details.
        """
        super().fuse_lora(
            components=components,
            lora_scale=lora_scale,
            safe_fusing=safe_fusing,
            adapter_names=adapter_names,
            **kwargs,
        )

    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
    def unfuse_lora(self, components: list[str] = ["transformer"], **kwargs):
        r"""
        See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details.
        """
        super().unfuse_lora(components=components, **kwargs)
6042+
58326043class LoraLoaderMixin (StableDiffusionLoraLoaderMixin ):
58336044 def __init__ (self , * args , ** kwargs ):
58346045 deprecation_message = "LoraLoaderMixin is deprecated and this will be removed in a future version. Please use `StableDiffusionLoraLoaderMixin`, instead."
0 commit comments