Skip to content

Commit 74ef9d2

Browse files
fix: handle shard models in offline mode (#160)
1 parent bd6aa2c commit 74ef9d2

File tree

1 file changed

+9
-0
lines changed

1 file changed

+9
-0
lines changed

src/streamdiffusion/modules/controlnet_module.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -574,6 +574,7 @@ def _prepare_control_image(self, control_image: Union[str, Any, torch.Tensor], p
574574
def _load_pytorch_controlnet_model(self, model_id: str, conditioning_channels: Optional[int] = None) -> ControlNetModel:
575575
from pathlib import Path
576576
import logging
577+
import os
577578
logger = logging.getLogger(__name__)
578579

579580
try:
@@ -582,6 +583,9 @@ def _load_pytorch_controlnet_model(self, model_id: str, conditioning_channels: O
582583
if conditioning_channels is not None:
583584
load_kwargs["conditioning_channels"] = conditioning_channels
584585

586+
# Check if offline mode is enabled via environment variables
587+
is_offline = os.environ.get("HF_HUB_OFFLINE", "0") == "1" or os.environ.get("TRANSFORMERS_OFFLINE", "0") == "1"
588+
585589
if Path(model_id).exists():
586590
model_path = Path(model_id)
587591

@@ -601,6 +605,11 @@ def _load_pytorch_controlnet_model(self, model_id: str, conditioning_channels: O
601605
load_kwargs["local_files_only"] = True
602606
controlnet = ControlNetModel.from_pretrained(model_id, **load_kwargs)
603607
else:
608+
# Loading from HuggingFace Hub - respect offline mode
609+
if is_offline:
610+
load_kwargs["local_files_only"] = True
611+
logger.info(f"ControlNetModule._load_pytorch_controlnet_model: Offline mode enabled, loading '{model_id}' from cache only")
612+
604613
if "/" in model_id and model_id.count("/") > 1:
605614
parts = model_id.split("/")
606615
repo_id = "/".join(parts[:2])

0 commit comments

Comments
 (0)