Skip to content

Commit a37f6f8

Browse files
hlkygithub-actions[bot]DN6
authored
Improve trust_remote_code (#13448)
* Robust trust check for custom_pipeline parameter of DiffusionPipeline.from_pretrained method * test_custom_components_from_local_dir * Apply style fixes * fix * Update src/diffusers/utils/dynamic_modules_utils.py Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com> * Adjust tests and allow community pipeline * DIFFUSERS_DISABLE_REMOTE_CODE --------- Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
1 parent 267b7a0 commit a37f6f8

7 files changed

Lines changed: 178 additions & 27 deletions

File tree

src/diffusers/models/auto_model.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ def from_config(cls, pretrained_model_name_or_path_or_dict: str | os.PathLike |
120120
subfolder=subfolder,
121121
module_file=module_file,
122122
class_name=class_name,
123+
trust_remote_code=trust_remote_code,
123124
**hub_kwargs,
124125
)
125126
else:
@@ -143,6 +144,7 @@ def from_config(cls, pretrained_model_name_or_path_or_dict: str | os.PathLike |
143144
importable_classes=ALL_IMPORTABLE_CLASSES,
144145
pipelines=None,
145146
is_pipeline_module=False,
147+
trust_remote_code=trust_remote_code,
146148
)
147149

148150
if model_cls is None:
@@ -318,6 +320,7 @@ def from_pretrained(cls, pretrained_model_or_path: str | os.PathLike | None = No
318320
subfolder=subfolder,
319321
module_file=module_file,
320322
class_name=class_name,
323+
trust_remote_code=trust_remote_code,
321324
**hub_kwargs,
322325
)
323326
else:

src/diffusers/modular_pipelines/modular_pipeline.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -437,6 +437,7 @@ def from_pretrained(
437437
pretrained_model_name_or_path,
438438
module_file=module_file,
439439
class_name=class_name,
440+
trust_remote_code=trust_remote_code,
440441
**hub_kwargs,
441442
)
442443
expected_kwargs, optional_kwargs = block_cls._get_signature_keys(block_cls)

src/diffusers/pipelines/pipeline_loading_utils.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -410,7 +410,14 @@ def simple_get_class_obj(library_name, class_name):
410410

411411

412412
def get_class_obj_and_candidates(
413-
library_name, class_name, importable_classes, pipelines, is_pipeline_module, component_name=None, cache_dir=None
413+
library_name,
414+
class_name,
415+
importable_classes,
416+
pipelines,
417+
is_pipeline_module,
418+
component_name=None,
419+
cache_dir=None,
420+
trust_remote_code: bool = False,
414421
):
415422
"""Simple helper method to retrieve class object of module as well as potential parent class objects"""
416423
component_folder = os.path.join(cache_dir, component_name) if component_name and cache_dir else None
@@ -426,7 +433,10 @@ def get_class_obj_and_candidates(
426433
elif component_folder and os.path.isfile(os.path.join(component_folder, library_name + ".py")):
427434
# load custom component
428435
class_obj = get_class_from_dynamic_module(
429-
component_folder, module_file=library_name + ".py", class_name=class_name
436+
component_folder,
437+
module_file=library_name + ".py",
438+
class_name=class_name,
439+
trust_remote_code=trust_remote_code,
430440
)
431441
class_candidates = dict.fromkeys(importable_classes.keys(), class_obj)
432442
else:
@@ -450,6 +460,7 @@ def _get_custom_pipeline_class(
450460
class_name=None,
451461
cache_dir=None,
452462
revision=None,
463+
trust_remote_code: bool = False,
453464
):
454465
if custom_pipeline.endswith(".py"):
455466
path = Path(custom_pipeline)
@@ -473,6 +484,7 @@ def _get_custom_pipeline_class(
473484
class_name=class_name,
474485
cache_dir=cache_dir,
475486
revision=revision,
487+
trust_remote_code=trust_remote_code,
476488
)
477489

478490

@@ -486,6 +498,7 @@ def _get_pipeline_class(
486498
class_name=None,
487499
cache_dir=None,
488500
revision=None,
501+
trust_remote_code: bool = False,
489502
):
490503
if custom_pipeline is not None:
491504
return _get_custom_pipeline_class(
@@ -495,6 +508,7 @@ def _get_pipeline_class(
495508
class_name=class_name,
496509
cache_dir=cache_dir,
497510
revision=revision,
511+
trust_remote_code=trust_remote_code,
498512
)
499513

500514
if class_obj.__name__ != "DiffusionPipeline" and class_obj.__name__ != "ModularPipeline":
@@ -766,6 +780,7 @@ def load_sub_model(
766780
disable_mmap: bool,
767781
quantization_config: Any | None = None,
768782
use_flashpack: bool = False,
783+
trust_remote_code: bool = False,
769784
):
770785
"""Helper method to load the module `name` from `library_name` and `class_name`"""
771786
from ..quantizers import PipelineQuantizationConfig
@@ -780,6 +795,7 @@ def load_sub_model(
780795
is_pipeline_module,
781796
component_name=name,
782797
cache_dir=cached_folder,
798+
trust_remote_code=trust_remote_code,
783799
)
784800

785801
load_method_name = None

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 9 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -787,6 +787,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike, **kwa
787787
quantization_config = kwargs.pop("quantization_config", None)
788788
use_flashpack = kwargs.pop("use_flashpack", False)
789789
disable_mmap = kwargs.pop("disable_mmap", False)
790+
trust_remote_code = kwargs.pop("trust_remote_code", False)
790791

791792
if torch_dtype is not None and not isinstance(torch_dtype, dict) and not isinstance(torch_dtype, torch.dtype):
792793
torch_dtype = torch.float32
@@ -871,6 +872,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike, **kwa
871872
variant=variant,
872873
dduf_file=dduf_file,
873874
load_connected_pipeline=load_connected_pipeline,
875+
trust_remote_code=trust_remote_code,
874876
**kwargs,
875877
)
876878
else:
@@ -928,6 +930,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike, **kwa
928930
class_name=custom_class_name,
929931
cache_dir=cache_dir,
930932
revision=custom_revision,
933+
trust_remote_code=trust_remote_code,
931934
)
932935

933936
if device_map is not None and pipeline_class._load_connected_pipes:
@@ -1077,6 +1080,7 @@ def load_module(name, value):
10771080
disable_mmap=disable_mmap,
10781081
quantization_config=quantization_config,
10791082
use_flashpack=use_flashpack,
1083+
trust_remote_code=trust_remote_code,
10801084
)
10811085
logger.info(
10821086
f"Loaded {name} as {class_name} from `{name}` subfolder of {pretrained_model_name_or_path}."
@@ -1684,21 +1688,6 @@ def download(cls, pretrained_model_name, **kwargs) -> str | os.PathLike:
16841688
custom_class_name = config_dict["_class_name"][1]
16851689

16861690
load_pipe_from_hub = custom_pipeline is not None and f"{custom_pipeline}.py" in filenames
1687-
load_components_from_hub = len(custom_components) > 0
1688-
1689-
if load_pipe_from_hub and not trust_remote_code:
1690-
raise ValueError(
1691-
f"The repository for {pretrained_model_name} contains custom code in {custom_pipeline}.py which must be executed to correctly "
1692-
f"load the model. You can inspect the repository content at https://hf.co/{pretrained_model_name}/blob/main/{custom_pipeline}.py.\n"
1693-
f"Please pass the argument `trust_remote_code=True` to allow custom code to be run."
1694-
)
1695-
1696-
if load_components_from_hub and not trust_remote_code:
1697-
raise ValueError(
1698-
f"The repository for {pretrained_model_name} contains custom code in {'.py, '.join([os.path.join(k, v) for k, v in custom_components.items()])} which must be executed to correctly "
1699-
f"load the model. You can inspect the repository content at {', '.join([f'https://hf.co/{pretrained_model_name}/{k}/{v}.py' for k, v in custom_components.items()])}.\n"
1700-
f"Please pass the argument `trust_remote_code=True` to allow custom code to be run."
1701-
)
17021691

17031692
# retrieve passed components that should not be downloaded
17041693
pipeline_class = _get_pipeline_class(
@@ -1711,6 +1700,7 @@ def download(cls, pretrained_model_name, **kwargs) -> str | os.PathLike:
17111700
class_name=custom_class_name,
17121701
cache_dir=cache_dir,
17131702
revision=custom_revision,
1703+
trust_remote_code=trust_remote_code,
17141704
)
17151705
expected_components, _ = cls._get_signature_keys(pipeline_class)
17161706
passed_components = [k for k in expected_components if k in kwargs]
@@ -2127,13 +2117,16 @@ def from_pipe(cls, pipeline, **kwargs):
21272117

21282118
original_config = dict(pipeline.config)
21292119
torch_dtype = kwargs.pop("torch_dtype", torch.float32)
2120+
trust_remote_code = kwargs.pop("trust_remote_code", False)
21302121

21312122
# derive the pipeline class to instantiate
21322123
custom_pipeline = kwargs.pop("custom_pipeline", None)
21332124
custom_revision = kwargs.pop("custom_revision", None)
21342125

21352126
if custom_pipeline is not None:
2136-
pipeline_class = _get_custom_pipeline_class(custom_pipeline, revision=custom_revision)
2127+
pipeline_class = _get_custom_pipeline_class(
2128+
custom_pipeline, revision=custom_revision, trust_remote_code=trust_remote_code
2129+
)
21372130
else:
21382131
pipeline_class = cls
21392132

src/diffusers/utils/dynamic_modules_utils.py

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,7 @@ def get_cached_module_file(
254254
revision: str | None = None,
255255
local_files_only: bool = False,
256256
local_dir: str | None = None,
257+
trust_remote_code: bool = False,
257258
):
258259
"""
259260
Prepares Downloads a module from a local folder or a distant repo and returns its path inside the cached
@@ -289,6 +290,10 @@ def get_cached_module_file(
289290
identifier allowed by git.
290291
local_files_only (`bool`, *optional*, defaults to `False`):
291292
If `True`, will only try to load the tokenizer configuration from local files.
293+
trust_remote_code (`bool`, *optional*, defaults to `False`):
294+
Whether or not to allow for custom pipelines and components defined on the Hub in their own files. This
295+
option should only be set to `True` for repositories you trust and in which you have read the code, as it
296+
will execute code present on the Hub on your local machine.
292297
293298
> [!TIP] > You may pass a token in `token` if you are not logged in (`hf auth login`) and want to use private or
294299
[gated > models](https://huggingface.co/docs/hub/models-gated#gated-models).
@@ -299,15 +304,29 @@ def get_cached_module_file(
299304
# Download and cache module_file from the repo `pretrained_model_name_or_path` of grab it if it's a local file.
300305
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
301306

307+
if DIFFUSERS_DISABLE_REMOTE_CODE:
308+
raise ValueError(
309+
"Downloading remote code is disabled globally via the DIFFUSERS_DISABLE_REMOTE_CODE environment variable."
310+
)
311+
302312
if subfolder is not None:
303313
module_file_or_url = os.path.join(pretrained_model_name_or_path, subfolder, module_file)
304314
else:
305315
module_file_or_url = os.path.join(pretrained_model_name_or_path, module_file)
306316

307-
if os.path.isfile(module_file_or_url):
317+
is_local_file = os.path.isfile(module_file_or_url)
318+
is_community_pipeline = not is_local_file and pretrained_model_name_or_path.count("/") == 0
319+
320+
if is_local_file:
308321
resolved_module_file = module_file_or_url
309322
submodule = "local"
310-
elif pretrained_model_name_or_path.count("/") == 0:
323+
if not trust_remote_code:
324+
raise ValueError(
325+
f"The directory {pretrained_model_name_or_path} contains custom code in {module_file} which must be executed to correctly "
326+
f"load the model. You can inspect the file content at {module_file_or_url}.\n"
327+
f"Pass `trust_remote_code=True` to allow loading remote code modules."
328+
)
329+
elif is_community_pipeline:
311330
available_versions = get_diffusers_versions()
312331
# cut ".dev0"
313332
latest_version = "v" + ".".join(__version__.split(".")[:3])
@@ -349,6 +368,12 @@ def get_cached_module_file(
349368
logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.")
350369
raise
351370
else:
371+
if not trust_remote_code:
372+
raise ValueError(
373+
f"The repository for {pretrained_model_name_or_path} contains custom code in {module_file} which must be executed to correctly "
374+
f"load the model. You can inspect the repository content at https://hf.co/{pretrained_model_name_or_path}/blob/main/{module_file}.\n"
375+
f"Pass `trust_remote_code=True` to allow loading remote code modules."
376+
)
352377
try:
353378
# Load from URL or cache if already cached
354379
resolved_module_file = hf_hub_download(
@@ -426,6 +451,7 @@ def get_cached_module_file(
426451
revision=revision,
427452
local_files_only=local_files_only,
428453
local_dir=local_dir,
454+
trust_remote_code=trust_remote_code,
429455
)
430456
return os.path.join(full_submodule, module_file)
431457

@@ -443,6 +469,7 @@ def get_class_from_dynamic_module(
443469
revision: str | None = None,
444470
local_files_only: bool = False,
445471
local_dir: str | None = None,
472+
trust_remote_code: bool = False,
446473
):
447474
"""
448475
Extracts a class from a module file, present in the local folder or repository of a model.
@@ -482,6 +509,10 @@ def get_class_from_dynamic_module(
482509
identifier allowed by git.
483510
local_files_only (`bool`, *optional*, defaults to `False`):
484511
If `True`, will only try to load the tokenizer configuration from local files.
512+
trust_remote_code (`bool`, *optional*, defaults to `False`):
513+
Whether or not to allow for custom pipelines and components defined on the Hub in their own files. This
514+
option should only be set to `True` for repositories you trust and in which you have read the code, as it
515+
will execute code present on the Hub on your local machine.
485516
486517
> [!TIP] > You may pass a token in `token` if you are not logged in (`hf auth login`) and want to use private or
487518
[gated > models](https://huggingface.co/docs/hub/models-gated#gated-models).
@@ -508,5 +539,6 @@ def get_class_from_dynamic_module(
508539
revision=revision,
509540
local_files_only=local_files_only,
510541
local_dir=local_dir,
542+
trust_remote_code=trust_remote_code,
511543
)
512544
return get_class_in_module(class_name, final_module)

tests/models/test_models_auto.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ def test_from_config_with_dict_diffusers_class(self, mock_get_class):
9999
importable_classes=unittest.mock.ANY,
100100
pipelines=None,
101101
is_pipeline_module=False,
102+
trust_remote_code=False,
102103
)
103104
mock_get_class.return_value[0].from_config.assert_called_once_with(config)
104105
assert result is mock_model
@@ -139,6 +140,7 @@ def test_from_config_with_model_type_routes_to_transformers(self, mock_get_class
139140
importable_classes=unittest.mock.ANY,
140141
pipelines=None,
141142
is_pipeline_module=False,
143+
trust_remote_code=False,
142144
)
143145
assert result is mock_model
144146

0 commit comments

Comments
 (0)