diff --git a/scripts/trt.py b/scripts/trt.py index ff0bbc7..44203f1 100644 --- a/scripts/trt.py +++ b/scripts/trt.py @@ -349,8 +349,10 @@ def process_batch(self, p, *args, **kwargs): if hasattr(p, "controlnet") and sd_unet.current_unet is not None: sd_unet.current_unet.cnets = p.controlnet - else: + elif sd_unet.current_unet is not None: sd_unet.current_unet.cnets = None + else: + sd_unet.current_unet = None def before_hr(self, p, *args): if self.idx != self.hr_idx: