Skip to content

Commit 303c1d8

Browse files
asomozasayakpaulDN6
authored
[Ernie-Image] Add lora support (#13575)
add lora support Co-authored-by: Sayak Paul <spsayakpaul@gmail.com> Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
1 parent 716f246 commit 303c1d8

5 files changed

Lines changed: 222 additions & 2 deletions

File tree

docs/source/en/api/loaders/lora.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi
3434
- [`QwenImageLoraLoaderMixin`] provides similar functions for [Qwen Image](https://huggingface.co/docs/diffusers/main/en/api/pipelines/qwen).
3535
- [`ZImageLoraLoaderMixin`] provides similar functions for [Z-Image](https://huggingface.co/docs/diffusers/main/en/api/pipelines/zimage).
3636
- [`Flux2LoraLoaderMixin`] provides similar functions for [Flux2](https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux2).
37+
- [`ErnieImageLoraLoaderMixin`] provides similar functions for [Ernie-Image](https://huggingface.co/docs/diffusers/main/en/api/pipelines/ernie_image).
3738
- [`LTX2LoraLoaderMixin`] provides similar functions for [LTX2](https://huggingface.co/docs/diffusers/main/en/api/pipelines/ltx2).
3839
- [`LoraBaseMixin`] provides a base class with several utility methods to fuse, unfuse, unload, LoRAs and more.
3940

@@ -64,6 +65,10 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi
6465

6566
[[autodoc]] loaders.lora_pipeline.Flux2LoraLoaderMixin
6667

68+
## ErnieImageLoraLoaderMixin
69+
70+
[[autodoc]] loaders.lora_pipeline.ErnieImageLoraLoaderMixin
71+
6772
## LTX2LoraLoaderMixin
6873

6974
[[autodoc]] loaders.lora_pipeline.LTX2LoraLoaderMixin

src/diffusers/loaders/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ def text_encoder_attn_modules(text_encoder):
8585
"QwenImageLoraLoaderMixin",
8686
"ZImageLoraLoaderMixin",
8787
"Flux2LoraLoaderMixin",
88+
"ErnieImageLoraLoaderMixin",
8889
]
8990
_import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"]
9091
_import_structure["ip_adapter"] = [
@@ -117,6 +118,7 @@ def text_encoder_attn_modules(text_encoder):
117118
AuraFlowLoraLoaderMixin,
118119
CogVideoXLoraLoaderMixin,
119120
CogView4LoraLoaderMixin,
121+
ErnieImageLoraLoaderMixin,
120122
Flux2LoraLoaderMixin,
121123
FluxLoraLoaderMixin,
122124
HeliosLoraLoaderMixin,

src/diffusers/loaders/lora_pipeline.py

Lines changed: 211 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5829,6 +5829,217 @@ def unfuse_lora(self, components: list[str] = ["transformer"], **kwargs):
58295829
super().unfuse_lora(components=components, **kwargs)
58305830

58315831

5832+
class ErnieImageLoraLoaderMixin(LoraBaseMixin):
    r"""
    Load LoRA layers into [`ErnieImageTransformer2DModel`]. Specific to [`ErnieImagePipeline`].
    """

    # Only the transformer carries LoRA layers for Ernie-Image; no text-encoder LoRA.
    _lora_loadable_modules = ["transformer"]
    transformer_name = TRANSFORMER_NAME

    @classmethod
    @validate_hf_hub_args
    def lora_state_dict(
        cls,
        pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor],
        **kwargs,
    ):
        r"""
        See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details.
        """
        # Load the main state dict first which has the LoRA layers for either of
        # transformer and text encoder or both.
        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", None)
        token = kwargs.pop("token", None)
        revision = kwargs.pop("revision", None)
        subfolder = kwargs.pop("subfolder", None)
        weight_name = kwargs.pop("weight_name", None)
        use_safetensors = kwargs.pop("use_safetensors", None)
        return_lora_metadata = kwargs.pop("return_lora_metadata", False)

        allow_pickle = False
        if use_safetensors is None:
            # Prefer safetensors, but allow falling back to pickle checkpoints
            # when the caller did not explicitly request safetensors.
            use_safetensors = True
            allow_pickle = True

        user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}

        state_dict, metadata = _fetch_state_dict(
            pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
            weight_name=weight_name,
            use_safetensors=use_safetensors,
            local_files_only=local_files_only,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            token=token,
            revision=revision,
            subfolder=subfolder,
            user_agent=user_agent,
            allow_pickle=allow_pickle,
        )

        # DoRA checkpoints are not supported yet; drop their extra scale tensors
        # rather than failing the load.
        is_dora_scale_present = any("dora_scale" in k for k in state_dict)
        if is_dora_scale_present:
            warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
            logger.warning(warn_msg)
            state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}

        # PEFT format -> normalize to diffusion_model.* prefix. A startswith-guarded
        # removeprefix (rather than str.replace) guarantees only the leading prefix
        # is rewritten, never an accidental mid-key occurrence of the substring.
        is_peft_format = any(k.startswith("base_model.model.") for k in state_dict)
        if is_peft_format:
            state_dict = {
                ("diffusion_model." + k.removeprefix("base_model.model."))
                if k.startswith("base_model.model.")
                else k: v
                for k, v in state_dict.items()
            }

        # AI-Toolkit / diffusion_model.* prefix -> swap to transformer.*
        # The Ernie LoRA naming under diffusion_model.* already matches diffusers module
        # paths (layers.X.self_attention.to_q etc.), so only the prefix needs to change.
        is_diffusion_model_prefix = any(k.startswith("diffusion_model.") for k in state_dict)
        if is_diffusion_model_prefix:
            state_dict = {
                ("transformer." + k.removeprefix("diffusion_model."))
                if k.startswith("diffusion_model.")
                else k: v
                for k, v in state_dict.items()
            }

        out = (state_dict, metadata) if return_lora_metadata else state_dict
        return out

    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
    def load_lora_weights(
        self,
        pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor],
        adapter_name: str | None = None,
        hotswap: bool = False,
        **kwargs,
    ):
        """
        See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for more details.
        """
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for this method.")

        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
        if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
            raise ValueError(
                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
            )

        # if a dict is passed, copy it instead of modifying it inplace
        if isinstance(pretrained_model_name_or_path_or_dict, dict):
            pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()

        # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
        kwargs["return_lora_metadata"] = True
        state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)

        is_correct_format = all("lora" in key for key in state_dict.keys())
        if not is_correct_format:
            raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")

        self.load_lora_into_transformer(
            state_dict,
            transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
            adapter_name=adapter_name,
            metadata=metadata,
            _pipeline=self,
            low_cpu_mem_usage=low_cpu_mem_usage,
            hotswap=hotswap,
        )

    @classmethod
    # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->ErnieImageTransformer2DModel
    def load_lora_into_transformer(
        cls,
        state_dict,
        transformer,
        adapter_name=None,
        _pipeline=None,
        low_cpu_mem_usage=False,
        hotswap: bool = False,
        metadata=None,
    ):
        """
        See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details.
        """
        if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
            raise ValueError(
                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
            )

        # Load the layers corresponding to transformer.
        logger.info(f"Loading {cls.transformer_name}.")
        transformer.load_lora_adapter(
            state_dict,
            network_alphas=None,
            adapter_name=adapter_name,
            metadata=metadata,
            _pipeline=_pipeline,
            low_cpu_mem_usage=low_cpu_mem_usage,
            hotswap=hotswap,
        )

    @classmethod
    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
    def save_lora_weights(
        cls,
        save_directory: str | os.PathLike,
        transformer_lora_layers: dict[str, torch.nn.Module | torch.Tensor] = None,
        is_main_process: bool = True,
        weight_name: str = None,
        save_function: Callable = None,
        safe_serialization: bool = True,
        transformer_lora_adapter_metadata: dict | None = None,
    ):
        r"""
        See [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for more information.
        """
        lora_layers = {}
        lora_metadata = {}

        if transformer_lora_layers:
            lora_layers[cls.transformer_name] = transformer_lora_layers
            lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata

        if not lora_layers:
            raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")

        cls._save_lora_weights(
            save_directory=save_directory,
            lora_layers=lora_layers,
            lora_metadata=lora_metadata,
            is_main_process=is_main_process,
            weight_name=weight_name,
            save_function=save_function,
            safe_serialization=safe_serialization,
        )

    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
    def fuse_lora(
        self,
        components: list[str] = ["transformer"],
        lora_scale: float = 1.0,
        safe_fusing: bool = False,
        adapter_names: list[str] | None = None,
        **kwargs,
    ):
        r"""
        See [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] for more details.
        """
        super().fuse_lora(
            components=components,
            lora_scale=lora_scale,
            safe_fusing=safe_fusing,
            adapter_names=adapter_names,
            **kwargs,
        )

    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
    def unfuse_lora(self, components: list[str] = ["transformer"], **kwargs):
        r"""
        See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details.
        """
        super().unfuse_lora(components=components, **kwargs)
6041+
6042+
58326043
class LoraLoaderMixin(StableDiffusionLoraLoaderMixin):
58336044
def __init__(self, *args, **kwargs):
58346045
deprecation_message = "LoraLoaderMixin is deprecated and this will be removed in a future version. Please use `StableDiffusionLoraLoaderMixin`, instead."

src/diffusers/models/transformers/transformer_ernie_image.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import torch.nn.functional as F
2626

2727
from ...configuration_utils import ConfigMixin, register_to_config
28+
from ...loaders import PeftAdapterMixin
2829
from ...utils import BaseOutput, logging
2930
from ..attention import AttentionModuleMixin
3031
from ..attention_dispatch import dispatch_attention_fn
@@ -288,7 +289,7 @@ def forward(self, x: torch.Tensor, conditioning: torch.Tensor) -> torch.Tensor:
288289
return x
289290

290291

291-
class ErnieImageTransformer2DModel(ModelMixin, ConfigMixin):
292+
class ErnieImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
292293
_supports_gradient_checkpointing = True
293294
_repeated_blocks = ["ErnieImageSharedAdaLNBlock"]
294295

src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from PIL import Image
2424
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer
2525

26+
from ...loaders import ErnieImageLoraLoaderMixin
2627
from ...models import AutoencoderKLFlux2
2728
from ...models.transformers import ErnieImageTransformer2DModel
2829
from ...pipelines.pipeline_utils import DiffusionPipeline
@@ -31,7 +32,7 @@
3132
from .pipeline_output import ErnieImagePipelineOutput
3233

3334

34-
class ErnieImagePipeline(DiffusionPipeline):
35+
class ErnieImagePipeline(DiffusionPipeline, ErnieImageLoraLoaderMixin):
3536
"""
3637
Pipeline for text-to-image generation using ErnieImageTransformer2DModel.
3738

0 commit comments

Comments
 (0)