From 176ffee5da6a62d3d6aba5a83856d8faa63db346 Mon Sep 17 00:00:00 2001
From: Michael Aubry
Date: Fri, 7 Feb 2025 00:18:29 -0800
Subject: [PATCH 1/8] fix: allow processing of multiple images

---
 src/rp_handler.py | 56 +++++++++++++++++++++++------------------------
 1 file changed, 28 insertions(+), 28 deletions(-)

diff --git a/src/rp_handler.py b/src/rp_handler.py
index 314b9057..e6097c8e 100644
--- a/src/rp_handler.py
+++ b/src/rp_handler.py
@@ -232,45 +232,45 @@ def process_output_images(outputs, job_id):
     # The path where ComfyUI stores the generated images
     COMFY_OUTPUT_PATH = os.environ.get("COMFY_OUTPUT_PATH", "/comfyui/output")

-    output_images = {}
+    # Changed from a dictionary to a list to collect multiple image paths
+    output_images = []

     for node_id, node_output in outputs.items():
         if "images" in node_output:
             for image in node_output["images"]:
-                output_images = os.path.join(image["subfolder"], image["filename"])
+                # Append each image's path (combining subfolder and filename) instead of overwriting
+                output_images.append(os.path.join(image["subfolder"], image["filename"]))

     print(f"runpod-worker-comfy - image generation is done")

-    # expected image output folder
-    local_image_path = f"{COMFY_OUTPUT_PATH}/{output_images}"
+    processed_images = []  # List to store the processed images (either URL or base64)

-    print(f"runpod-worker-comfy - {local_image_path}")
+    # Process each image found in the outputs
+    for rel_path in output_images:
+        # expected image output folder
+        local_image_path = f"{COMFY_OUTPUT_PATH}/{rel_path}"

-    # The image is in the output folder
-    if os.path.exists(local_image_path):
-        if os.environ.get("BUCKET_ENDPOINT_URL", False):
-            # URL to image in AWS S3
-            image = rp_upload.upload_image(job_id, local_image_path)
-            print(
-                "runpod-worker-comfy - the image was generated and uploaded to AWS S3"
-            )
+        print(f"runpod-worker-comfy - {local_image_path}")
+
+        # The image is in the output folder
+        if os.path.exists(local_image_path):
+            if os.environ.get("BUCKET_ENDPOINT_URL", False):
+                # URL to image in AWS S3
+                image_result = rp_upload.upload_image(job_id, local_image_path)
+                print("runpod-worker-comfy - the image was generated and uploaded to AWS S3")
+            else:
+                # base64 image
+                image_result = base64_encode(local_image_path)
+                print("runpod-worker-comfy - the image was generated and converted to base64")
+            processed_images.append(image_result)
         else:
-            # base64 image
-            image = base64_encode(local_image_path)
-            print(
-                "runpod-worker-comfy - the image was generated and converted to base64"
-            )
+            print("runpod-worker-comfy - the image does not exist in the output folder")
+            processed_images.append(f"the image does not exist in the specified output folder: {local_image_path}")

-        return {
-            "status": "success",
-            "message": image,
-        }
-    else:
-        print("runpod-worker-comfy - the image does not exist in the output folder")
-        return {
-            "status": "error",
-            "message": f"the image does not exist in the specified output folder: {local_image_path}",
-        }
+    return {
+        "status": "success",
+        "message": processed_images,
+    }


 def handler(job):

From f042eac33118ac767f6a0e2e2ddba23ef47d37f0 Mon Sep 17 00:00:00 2001
From: Michael Aubry
Date: Wed, 26 Feb 2025 14:18:53 -0800
Subject: [PATCH 2/8] chore: add our own Redis queue system

---
 Dockerfile        |   2 +-
 requirements.txt  |   3 +-
 src/rp_handler.py | 289 ++++++++++++++++------------------------
 src/start.sh      |   2 +-
 4 files changed, 103 insertions(+), 193 deletions(-)

diff --git a/Dockerfile b/Dockerfile
index b13a7008..4f7f3d1c 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -33,7 +33,7 @@ RUN /usr/bin/yes | comfy --workspace /comfyui install --cuda-version 11.8 --nvid
 WORKDIR /comfyui

 # Install runpod
-RUN pip install runpod requests
+RUN pip install runpod requests redis

 # Support for the network volume
 ADD src/extra_model_paths.yaml ./
diff --git a/requirements.txt b/requirements.txt
index 1e4fbbdd..da890506 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1 +1,2 @@
-runpod==1.3.6
\ No newline at end of file
+runpod==1.3.6
+redis
\ No newline at end of file
diff --git a/src/rp_handler.py b/src/rp_handler.py
index e6097c8e..7bf194e0 100644
--- a/src/rp_handler.py
+++ b/src/rp_handler.py
@@ -8,6 +8,25 @@
 import requests
 import base64
 from io import BytesIO
+import redis
+from datetime import datetime
+import logging
+import uuid
+
+# Configure logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+# Redis configuration from environment variables (credentials must never be hardcoded)
+REDIS_HOST = os.environ.get("REDIS_HOST", "localhost")
+REDIS_PORT = int(os.environ.get("REDIS_PORT", 6379))
+REDIS_USERNAME = os.environ.get("REDIS_USERNAME", "default")
+REDIS_PASSWORD = os.environ.get("REDIS_PASSWORD")
+r = redis.Redis(
+    host=REDIS_HOST,
+    port=REDIS_PORT,
+    decode_responses=True,
+    username=REDIS_USERNAME,
+    password=REDIS_PASSWORD,
+)

 # Time to wait between API check attempts in milliseconds
 COMFY_API_AVAILABLE_INTERVAL_MS = 50
@@ -20,117 +39,79 @@
 # Host where ComfyUI is running
 COMFY_HOST = "127.0.0.1:8188"
 # Enforce a clean state after each job is done
-# see https://docs.runpod.io/docs/handler-additional-controls#refresh-worker
 REFRESH_WORKER = os.environ.get("REFRESH_WORKER", "false").lower() == "true"

+def update_redis(job_id, state, workflow=None, result=None, error=None):
+    """Update job state in Redis."""
+    job = {
+        "id": job_id,
+        "state": state,
+        "workflow": workflow if workflow is not None else {},
+        "result": result,
+        "error": error,
+        "updated_at": datetime.utcnow().isoformat()
+    }
+    try:
+        r.set(f"job:{job_id}", json.dumps(job))
+        logger.info(f"Updated job {job_id} state to {state}")
+    except Exception as e:
+        logger.error(f"Failed to update Redis for job {job_id}: {str(e)}")

 def validate_input(job_input):
-    """
-    Validates the input for the handler function.
-
-    Args:
-        job_input (dict): The input data to validate.
-
-    Returns:
-        tuple: A tuple containing the validated data and an error message, if any.
-               The structure is (validated_data, error_message).
- """ - # Validate if job_input is provided + """Validates the input for the handler function.""" if job_input is None: return None, "Please provide input" - # Check if input is a string and try to parse it as JSON if isinstance(job_input, str): try: job_input = json.loads(job_input) except json.JSONDecodeError: return None, "Invalid JSON format in input" - # Validate 'workflow' in input workflow = job_input.get("workflow") if workflow is None: return None, "Missing 'workflow' parameter" - # Validate 'images' in input, if provided images = job_input.get("images") if images is not None: if not isinstance(images, list) or not all( "name" in image and "image" in image for image in images ): - return ( - None, - "'images' must be a list of objects with 'name' and 'image' keys", - ) + return None, "'images' must be a list of objects with 'name' and 'image' keys" - # Return validated data and no error return {"workflow": workflow, "images": images}, None - def check_server(url, retries=500, delay=50): - """ - Check if a server is reachable via HTTP GET request - - Args: - - url (str): The URL to check - - retries (int, optional): The number of times to attempt connecting to the server. Default is 50 - - delay (int, optional): The time in milliseconds to wait between retries. Default is 500 - - Returns: - bool: True if the server is reachable within the given number of retries, otherwise False - """ - + """Check if a server is reachable via HTTP GET request.""" for i in range(retries): try: response = requests.get(url) - - # If the response status code is 200, the server is up and running if response.status_code == 200: - print(f"runpod-worker-comfy - API is reachable") + logger.info("ComfyUI API is reachable") return True - except requests.RequestException as e: - # If an exception occurs, the server may not be ready + except requests.RequestException: pass - - # Wait for the specified delay before retrying time.sleep(delay / 1000) - - print( - f"runpod-worker-comfy - Failed to connect to server at {url} after {retries} attempts." - ) + logger.error(f"Failed to connect to {url} after {retries} attempts") return False - def upload_images(images): - """ - Upload a list of base64 encoded images to the ComfyUI server using the /upload/image endpoint. - - Args: - images (list): A list of dictionaries, each containing the 'name' of the image and the 'image' as a base64 encoded string. - server_address (str): The address of the ComfyUI server. - - Returns: - list: A list of responses from the server for each image upload. 
- """ + """Upload a list of base64 encoded images to the ComfyUI server.""" if not images: return {"status": "success", "message": "No images to upload", "details": []} responses = [] upload_errors = [] - - print(f"runpod-worker-comfy - image(s) upload") + logger.info("Starting image upload") for image in images: name = image["name"] image_data = image["image"] blob = base64.b64decode(image_data) - - # Prepare the form data files = { "image": (name, BytesIO(blob), "image/png"), "overwrite": (None, "true"), } - - # POST request to upload the image response = requests.post(f"http://{COMFY_HOST}/upload/image", files=files) if response.status_code != 200: upload_errors.append(f"Error uploading {name}: {response.text}") @@ -138,213 +119,141 @@ def upload_images(images): responses.append(f"Successfully uploaded {name}") if upload_errors: - print(f"runpod-worker-comfy - image(s) upload with errors") + logger.error("Image upload completed with errors") return { "status": "error", "message": "Some images failed to upload", "details": upload_errors, } - - print(f"runpod-worker-comfy - image(s) upload complete") + logger.info("Image upload completed successfully") return { "status": "success", "message": "All images uploaded successfully", "details": responses, } - def queue_workflow(workflow): - """ - Queue a workflow to be processed by ComfyUI - - Args: - workflow (dict): A dictionary containing the workflow to be processed - - Returns: - dict: The JSON response from ComfyUI after processing the workflow - """ - - # The top level element "prompt" is required by ComfyUI + """Queue a workflow to be processed by ComfyUI.""" data = json.dumps({"prompt": workflow}).encode("utf-8") - req = urllib.request.Request(f"http://{COMFY_HOST}/prompt", data=data) return json.loads(urllib.request.urlopen(req).read()) - def get_history(prompt_id): - """ - Retrieve the history of a given prompt using its ID - - Args: - prompt_id (str): The ID of the prompt whose history is to be retrieved - - Returns: - dict: The history of the prompt, containing all the processing steps and results - """ + """Retrieve the history of a given prompt using its ID.""" with urllib.request.urlopen(f"http://{COMFY_HOST}/history/{prompt_id}") as response: return json.loads(response.read()) - def base64_encode(img_path): - """ - Returns base64 encoded image. - - Args: - img_path (str): The path to the image - - Returns: - str: The base64 encoded image - """ + """Returns base64 encoded image.""" with open(img_path, "rb") as image_file: - encoded_string = base64.b64encode(image_file.read()).decode("utf-8") - return f"{encoded_string}" - + return base64.b64encode(image_file.read()).decode("utf-8") def process_output_images(outputs, job_id): - """ - This function takes the "outputs" from image generation and the job ID, - then determines the correct way to return the image, either as a direct URL - to an AWS S3 bucket or as a base64 encoded string, depending on the - environment configuration. - - Args: - outputs (dict): A dictionary containing the outputs from image generation, - typically includes node IDs and their respective output data. - job_id (str): The unique identifier for the job. - - Returns: - dict: A dictionary with the status ('success' or 'error') and the message, - which is either the URL to the image in the AWS S3 bucket or a base64 - encoded string of the image. In case of error, the message details the issue. 
- - The function works as follows: - - It first determines the output path for the images from an environment variable, - defaulting to "/comfyui/output" if not set. - - It then iterates through the outputs to find the filenames of the generated images. - - After confirming the existence of the image in the output folder, it checks if the - AWS S3 bucket is configured via the BUCKET_ENDPOINT_URL environment variable. - - If AWS S3 is configured, it uploads the image to the bucket and returns the URL. - - If AWS S3 is not configured, it encodes the image in base64 and returns the string. - - If the image file does not exist in the output folder, it returns an error status - with a message indicating the missing image file. - """ - - # The path where ComfyUI stores the generated images + """Process generated images and return as S3 URL or base64.""" COMFY_OUTPUT_PATH = os.environ.get("COMFY_OUTPUT_PATH", "/comfyui/output") - - # Changed from a dictionary to a list to collect multiple image paths output_images = [] for node_id, node_output in outputs.items(): if "images" in node_output: for image in node_output["images"]: - # Append each image's path (combining subfolder and filename) instead of overwriting output_images.append(os.path.join(image["subfolder"], image["filename"])) - print(f"runpod-worker-comfy - image generation is done") + logger.info("Image generation completed") + processed_images = [] - processed_images = [] # List to store the processed images (either URL or base64) - - # Process each image found in the outputs for rel_path in output_images: - # expected image output folder local_image_path = f"{COMFY_OUTPUT_PATH}/{rel_path}" + logger.info(f"Processing image at {local_image_path}") - print(f"runpod-worker-comfy - {local_image_path}") - - # The image is in the output folder if os.path.exists(local_image_path): if os.environ.get("BUCKET_ENDPOINT_URL", False): - # URL to image in AWS S3 image_result = rp_upload.upload_image(job_id, local_image_path) - print("runpod-worker-comfy - the image was generated and uploaded to AWS S3") + logger.info("Image uploaded to AWS S3") else: - # base64 image image_result = base64_encode(local_image_path) - print("runpod-worker-comfy - the image was generated and converted to base64") + logger.info("Image converted to base64") processed_images.append(image_result) else: - print("runpod-worker-comfy - the image does not exist in the output folder") - processed_images.append(f"the image does not exist in the specified output folder: {local_image_path}") + logger.error(f"Image not found at {local_image_path}") + processed_images.append(f"Image not found: {local_image_path}") return { "status": "success", "message": processed_images, } - def handler(job): - """ - The main function that handles a job of generating an image. - - This function validates the input, sends a prompt to ComfyUI for processing, - polls ComfyUI for result, and retrieves generated images. - - Args: - job (dict): A dictionary containing job details and input parameters. - - Returns: - dict: A dictionary containing either an error message or a success status with generated images. 
- """ + """Main handler for processing a job.""" + job_id = job.get("id", str(uuid.uuid4())) # Use provided job ID or generate one job_input = job["input"] - # Make sure that the input is valid + # Validate input validated_data, error_message = validate_input(job_input) if error_message: - return {"error": error_message} + update_redis(job_id, "FAILED", error=error_message) + return {"error": error_message, "job_id": job_id} - # Extract validated data workflow = validated_data["workflow"] images = validated_data.get("images") - # Make sure that the ComfyUI API is available - check_server( - f"http://{COMFY_HOST}", - COMFY_API_AVAILABLE_MAX_RETRIES, - COMFY_API_AVAILABLE_INTERVAL_MS, - ) + # Write initial state to Redis + update_redis(job_id, "NOT_STARTED", workflow=workflow) - # Upload images if they exist - upload_result = upload_images(images) + # Check ComfyUI availability + if not check_server(f"http://{COMFY_HOST}", COMFY_API_AVAILABLE_MAX_RETRIES, COMFY_API_AVAILABLE_INTERVAL_MS): + error = "ComfyUI API unavailable" + update_redis(job_id, "FAILED", error=error) + return {"error": error, "job_id": job_id} + # Upload images if provided + upload_result = upload_images(images) if upload_result["status"] == "error": - return upload_result + update_redis(job_id, "FAILED", error=upload_result["message"]) + return {**upload_result, "job_id": job_id} # Queue the workflow try: + update_redis(job_id, "IN_QUEUE") queued_workflow = queue_workflow(workflow) prompt_id = queued_workflow["prompt_id"] - print(f"runpod-worker-comfy - queued workflow with ID {prompt_id}") + logger.info(f"Queued workflow with prompt ID {prompt_id}") except Exception as e: - return {"error": f"Error queuing workflow: {str(e)}"} + error = f"Error queuing workflow: {str(e)}" + update_redis(job_id, "FAILED", error=error) + return {"error": error, "job_id": job_id} # Poll for completion - print(f"runpod-worker-comfy - wait until image generation is complete") + logger.info("Polling for image generation completion") retries = 0 try: while retries < COMFY_POLLING_MAX_RETRIES: history = get_history(prompt_id) - - # Exit the loop if we have found the history if prompt_id in history and history[prompt_id].get("outputs"): break - else: - # Wait before trying again - time.sleep(COMFY_POLLING_INTERVAL_MS / 1000) - retries += 1 + time.sleep(COMFY_POLLING_INTERVAL_MS / 1000) + retries += 1 else: - return {"error": "Max retries reached while waiting for image generation"} + error = "Max retries reached while waiting for image generation" + update_redis(job_id, "FAILED", error=error) + return {"error": error, "job_id": job_id} except Exception as e: - return {"error": f"Error waiting for image generation: {str(e)}"} - - # Get the generated image and return it as URL in an AWS bucket or as base64 - images_result = process_output_images(history[prompt_id].get("outputs"), job["id"]) - - result = {**images_result, "refresh_worker": REFRESH_WORKER} + error = f"Error polling for image generation: {str(e)}" + update_redis(job_id, "FAILED", error=error) + return {"error": error, "job_id": job_id} - return result + # Process output images + images_result = process_output_images(history[prompt_id].get("outputs"), job_id) + if images_result["status"] == "success": + update_redis(job_id, "COMPLETED", result=images_result["message"]) + else: + update_redis(job_id, "FAILED", error="Image processing failed") + return { + "job_id": job_id, + "status": images_result["status"], + "message": images_result["message"], + "refresh_worker": REFRESH_WORKER + } 
-# Start the handler only if this script is run directly
 if __name__ == "__main__":
-    runpod.serverless.start({"handler": handler})
+    runpod.serverless.start({"handler": handler})
\ No newline at end of file
diff --git a/src/start.sh b/src/start.sh
index 6cb5cb84..ec445e49 100644
--- a/src/start.sh
+++ b/src/start.sh
@@ -13,7 +13,7 @@ if [ "$SERVE_API_LOCALLY" == "true" ]; then
     python3 -u /rp_handler.py --rp_serve_api --rp_api_host=0.0.0.0
 else
     echo "runpod-worker-comfy: Starting ComfyUI"
-    python3 /comfyui/main.py --disable-auto-launch --disable-metadata &
+    python3 /comfyui/main.py --disable-auto-launch --disable-metadata --highvram &

     echo "runpod-worker-comfy: Starting RunPod Handler"
     python3 -u /rp_handler.py

From f3cb6279f9f32a20da40a5ae582c50b0b58b93de Mon Sep 17 00:00:00 2001
From: Michael Aubry
Date: Wed, 26 Feb 2025 16:53:45 -0800
Subject: [PATCH 3/8] chore: add preloading of weights if needed

---
 README.md         |  48 ++++++++++---------
 src/rp_handler.py | 114 +++++++++++++++++++++++++++++++++++++++++++++-
 2 files changed, 140 insertions(+), 22 deletions(-)

diff --git a/README.md b/README.md
index 8222d0cc..2fc07218 100644
--- a/README.md
+++ b/README.md
@@ -1,3 +1,9 @@
+# How to build and push
+
+docker build -t basedlabs/runpod-worker-comfy:1.4 --target base --platform linux/amd64 .
+
+docker push basedlabs/runpod-worker-comfy:1.4
+
 # runpod-worker-comfy

 > [ComfyUI](https://github.com/comfyanonymous/ComfyUI) as a serverless API on [RunPod](https://www.runpod.io/)
@@ -17,33 +23,33 @@
 Read our article here: https://blib.la/blog/comfyui-on-runpod

 - [Quickstart](#quickstart)
 - [Features](#features)
 - [Config](#config)
-  * [Upload image to AWS S3](#upload-image-to-aws-s3)
+  - [Upload image to AWS S3](#upload-image-to-aws-s3)
 - [Use the Docker image on RunPod](#use-the-docker-image-on-runpod)
-  * [Create your template (optional)](#create-your-template-optional)
-  * [Create your endpoint](#create-your-endpoint)
-  * [GPU recommendations](#gpu-recommendations)
+  - [Create your template (optional)](#create-your-template-optional)
+  - [Create your endpoint](#create-your-endpoint)
+  - [GPU recommendations](#gpu-recommendations)
 - [API specification](#api-specification)
-  * [JSON Request Body](#json-request-body)
-  * [Fields](#fields)
-    + ["input.images"](#inputimages)
+  - [JSON Request Body](#json-request-body)
+  - [Fields](#fields)
+    - ["input.images"](#inputimages)
 - [Interact with your RunPod API](#interact-with-your-runpod-api)
-  * [Health status](#health-status)
-  * [Generate an image](#generate-an-image)
-    + [Example request for SDXL with cURL](#example-request-for-sdxl-with-curl)
+  - [Health status](#health-status)
+  - [Generate an image](#generate-an-image)
+    - [Example request for SDXL with cURL](#example-request-for-sdxl-with-curl)
 - [How to get the workflow from ComfyUI?](#how-to-get-the-workflow-from-comfyui)
 - [Bring Your Own Models and Nodes](#bring-your-own-models-and-nodes)
-  * [Network Volume](#network-volume)
-  * [Custom Docker Image](#custom-docker-image)
-    + [Adding Custom Models](#adding-custom-models)
-    + [Adding Custom Nodes](#adding-custom-nodes)
-    + [Building the Image](#building-the-image)
+  - [Network Volume](#network-volume)
+  - [Custom Docker Image](#custom-docker-image)
+    - [Adding Custom Models](#adding-custom-models)
+    - [Adding Custom Nodes](#adding-custom-nodes)
+    - [Building the Image](#building-the-image)
 - [Local testing](#local-testing)
-  * [Setup](#setup)
-    + [Setup for Windows](#setup-for-windows)
-  * [Testing the RunPod
handler](#testing-the-runpod-handler) - * [Local API](#local-api) - + [Access the local Worker API](#access-the-local-worker-api) - + [Access local ComfyUI](#access-local-comfyui) + - [Setup](#setup) + - [Setup for Windows](#setup-for-windows) + - [Testing the RunPod handler](#testing-the-runpod-handler) + - [Local API](#local-api) + - [Access the local Worker API](#access-the-local-worker-api) + - [Access local ComfyUI](#access-local-comfyui) - [Automatically deploy to Docker hub with GitHub Actions](#automatically-deploy-to-docker-hub-with-github-actions) - [Acknowledgments](#acknowledgments) diff --git a/src/rp_handler.py b/src/rp_handler.py index 7bf194e0..61b62a46 100644 --- a/src/rp_handler.py +++ b/src/rp_handler.py @@ -12,11 +12,17 @@ from datetime import datetime import logging import uuid +import sys # Configure logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) +# Add ComfyUI directory to Python path +comfy_path = "/comfyui" +if comfy_path not in sys.path: + sys.path.append(comfy_path) + # Redis configuration from environment variables REDIS_HOST = os.environ.get("REDIS_HOST", "localhost") REDIS_PORT = int(os.environ.get("REDIS_PORT", 6379)) @@ -255,5 +261,111 @@ def handler(job): "refresh_worker": REFRESH_WORKER } +def preload_weights(checkpoint_name): + try: + # Import ComfyUI modules after adding to path + from comfy.sd import load_checkpoint_guess_config + import folder_paths + + # Load the model + logger.info(f"Attempting to preload: {checkpoint_name}") + + # Use the exact path to the model in the RunPod volume + model_path = os.path.join('runpod-volume/models/checkpoints', checkpoint_name) + + if not os.path.exists(model_path): + logger.error(f"Model file not found at: {model_path}") + return None + + logger.info(f"Found model at: {model_path}") + + # Get the embedding directory from ComfyUI's folder_paths + embedding_directory = None + try: + embedding_directory = folder_paths.get_folder_paths("embeddings") + logger.info(f"Embedding directory: {embedding_directory}") + except Exception as e: + logger.warning(f"Could not get embedding directory: {str(e)}") + + # Call the function with the correct arguments as used in ComfyUI + logger.info(f"Calling load_checkpoint_guess_config with path: {model_path}") + result = load_checkpoint_guess_config( + model_path, + output_vae=True, + output_clip=True, + embedding_directory=embedding_directory + ) + + # Log the type of result to help with debugging + logger.info(f"Result type: {type(result)}") + + if isinstance(result, tuple): + logger.info(f"Result has {len(result)} elements") + # If it's a tuple, it might be (model, clip, vae) or (model, clip, vae, ...) 
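# (Editor's note, not part of the patch: in the ComfyUI builds I have seen, load_checkpoint_guess_config returns a 4-tuple (model, clip, vae, clipvision), which is why the code below unpacks only the first three elements; treat the exact arity as an assumption.)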
+ if len(result) >= 3: + model, clip, vae = result[:3] + logger.info(f"Successfully preloaded: {checkpoint_name}") + return model, clip, vae + else: + logger.error(f"Not enough values returned: expected at least 3, got {len(result)}") + return None + else: + # If it's not a tuple, it might be a single model object + logger.info(f"Result is not a tuple, treating as single model object") + return result, None, None + + except Exception as e: + logger.error(f"Error preloading model {checkpoint_name}: {str(e)}") + import traceback + logger.error(traceback.format_exc()) + return None + +def init(): + """Initialize the handler.""" + # Wait for ComfyUI to start up + logger.info("Waiting for ComfyUI to initialize...") + time.sleep(15) # Give ComfyUI time to start + + # Check if models directory exists + model_dir = 'runpod-volume/models/checkpoints' + if os.path.exists(model_dir): + logger.info(f"Model directory exists: {model_dir}") + try: + models = os.listdir(model_dir) + logger.info(f"Available models: {models}") + except Exception as e: + logger.error(f"Error listing models: {str(e)}") + else: + logger.warning(f"Model directory does not exist: {model_dir}") + + # Models to preload + models_to_preload = [ + "realvisxlV40_v40LightningBakedvae.safetensors", + "pixelArtDiffusionXL_pixelWorld.safetensors" + ] + + # Preload models + preloaded = [] + for checkpoint_name in models_to_preload: + try: + result = preload_weights(checkpoint_name) + if result is not None: + preloaded.append(checkpoint_name) + logger.info(f"Successfully preloaded: {checkpoint_name}") + except Exception as e: + logger.error(f"Failed to preload {checkpoint_name}: {str(e)}") + + logger.info(f"Preloaded {len(preloaded)} models: {preloaded}") + +# Only initialize if this module is the main program if __name__ == "__main__": - runpod.serverless.start({"handler": handler}) \ No newline at end of file + # Initialize after a delay to ensure ComfyUI is running + import threading + threading.Timer(5.0, init).start() + + # Start the handler + logger.info("Starting RunPod handler") + runpod.serverless.start({"handler": handler}) +else: + # If imported as a module, don't run init() immediately + logger.info("rp_handler module imported, initialization will be handled by main process") \ No newline at end of file From 39c7b275c5fb089b94a2f0fff7453605fc05c933 Mon Sep 17 00:00:00 2001 From: Michael Aubry Date: Thu, 27 Feb 2025 09:47:09 -0800 Subject: [PATCH 4/8] chore: add support for trainings and preload fluxdev --- src/rp_handler.py | 542 ++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 499 insertions(+), 43 deletions(-) diff --git a/src/rp_handler.py b/src/rp_handler.py index 61b62a46..acb3403e 100644 --- a/src/rp_handler.py +++ b/src/rp_handler.py @@ -46,6 +46,7 @@ COMFY_HOST = "127.0.0.1:8188" # Enforce a clean state after each job is done REFRESH_WORKER = os.environ.get("REFRESH_WORKER", "false").lower() == "true" +FLUX = os.environ.get("FLUX", "false").lower() == "true" def update_redis(job_id, state, workflow=None, result=None, error=None): """Update job state in Redis.""" @@ -84,8 +85,16 @@ def validate_input(job_input): "name" in image and "image" in image for image in images ): return None, "'images' must be a list of objects with 'name' and 'image' keys" + + # Validate loras parameter if provided + loras = job_input.get("loras") + if loras is not None: + if not isinstance(loras, list) or not all( + "path" in lora and "scale" in lora for lora in loras + ): + return None, "'loras' must be a list of objects with 
'path' and 'scale' keys" - return {"workflow": workflow, "images": images}, None + return {"workflow": workflow, "images": images, "loras": loras}, None def check_server(url, retries=500, delay=50): """Check if a server is reachable via HTTP GET request.""" @@ -188,10 +197,67 @@ def process_output_images(outputs, job_id): "message": processed_images, } +def verify_lora_files(lora_info): + """Verify that downloaded lora files are accessible to ComfyUI. + + Args: + lora_info: List of dictionaries with lora information + + Returns: + bool: True if all files are accessible, False otherwise + """ + try: + # Import ComfyUI's folder_paths module to check lora paths + import folder_paths + + # Get the lora folder paths from ComfyUI + lora_folders = folder_paths.get_folder_paths("loras") + logger.info(f"ComfyUI lora folders: {lora_folders}") + + # Check if our lora directory is in the ComfyUI lora folders + our_lora_dir = "/runpod-volume/models/loras" + if our_lora_dir not in lora_folders: + logger.warning(f"Our lora directory {our_lora_dir} is not in ComfyUI's lora folders") + # Try to add it to ComfyUI's folder paths + folder_paths.add_model_folder_path("loras", our_lora_dir) + logger.info(f"Added {our_lora_dir} to ComfyUI's lora folders") + + # Verify it was added + lora_folders = folder_paths.get_folder_paths("loras") + if our_lora_dir not in lora_folders: + logger.error(f"Failed to add {our_lora_dir} to ComfyUI's lora folders") + return False + + # Check if each downloaded lora file exists and is accessible + for lora in lora_info: + if lora.get("downloaded", False): + lora_path = lora.get("path") + if not os.path.exists(lora_path): + logger.error(f"Downloaded lora file does not exist: {lora_path}") + return False + + # Check if the file is readable + try: + with open(lora_path, 'rb') as f: + # Just read a small chunk to verify access + f.read(1024) + logger.info(f"Verified access to lora file: {lora_path}") + except Exception as e: + logger.error(f"Cannot read lora file {lora_path}: {str(e)}") + return False + + return True + except Exception as e: + logger.error(f"Error verifying lora files: {str(e)}") + import traceback + logger.error(traceback.format_exc()) + return False + def handler(job): """Main handler for processing a job.""" job_id = job.get("id", str(uuid.uuid4())) # Use provided job ID or generate one job_input = job["input"] + lora_file_paths = [] # Track downloaded lora files for cleanup # Validate input validated_data, error_message = validate_input(job_input) @@ -201,6 +267,7 @@ def handler(job): workflow = validated_data["workflow"] images = validated_data.get("images") + loras = validated_data.get("loras") # Write initial state to Redis update_redis(job_id, "NOT_STARTED", workflow=workflow) @@ -216,6 +283,88 @@ def handler(job): if upload_result["status"] == "error": update_redis(job_id, "FAILED", error=upload_result["message"]) return {**upload_result, "job_id": job_id} + + # Process loras if provided + if loras: + logger.info(f"Processing {len(loras)} loras") + update_redis(job_id, "PROCESSING_LORAS") + lora_result = download_lora_files(loras) + lora_file_paths = lora_result.get("file_paths", []) + lora_info = lora_result.get("lora_info", []) + + if lora_result["status"] == "error": + # Continue even if some files failed, but log the errors + logger.warning(f"Some loras failed to download: {lora_result['details']}") + update_redis(job_id, "LORA_PROCESSING_PARTIAL", error=lora_result["message"]) + else: + logger.info("All network loras processed successfully") + 
update_redis(job_id, "LORA_PROCESSING_COMPLETE") + + # Verify lora files are accessible to ComfyUI + if lora_info: + if not verify_lora_files(lora_info): + logger.warning("Some lora files may not be accessible to ComfyUI") + update_redis(job_id, "LORA_VERIFICATION_WARNING", error="Some lora files may not be accessible to ComfyUI") + else: + logger.info("All lora files are accessible to ComfyUI") + + # Update workflow with local paths for downloaded loras + if lora_info: + logger.info("Updating workflow with local lora paths") + try: + # Check if this is a FLUX workflow by looking for specific node types + is_flux_workflow = False + workflow_str = json.dumps(workflow) + + if '"class_type": "DualCLIPLoader"' in workflow_str and '"class_type": "UNETLoader"' in workflow_str: + is_flux_workflow = True + logger.info("Detected FLUX workflow structure") + + if is_flux_workflow: + # For FLUX workflows, we need to update the LoraLoader nodes + for node_id, node in workflow.items(): + if node.get("class_type") == "LoraLoader": + node_lora_name = node["inputs"].get("lora_name", "") + logger.info(f"Processing LoraLoader node {node_id} with lora_name: {node_lora_name}") + + # Find the corresponding lora in our downloaded list + for lora in lora_info: + if lora.get("downloaded", False): + original_path = lora.get("original_path", "") + original_filename = os.path.basename(original_path) + + # Check various matching conditions + if (node_lora_name == original_path or # Full path match + node_lora_name == original_filename or # Filename match + os.path.basename(node_lora_name) == original_filename): # Basename match + + # Update to use the local path's filename + node["inputs"]["lora_name"] = lora.get("lora_name") + logger.info(f"Updated LoraLoader node {node_id} to use {lora.get('lora_name')}") + break + else: + # For standard workflows, do a simple string replacement + path_mapping = {} + for lora in lora_info: + if lora.get("downloaded", False): + path_mapping[lora["original_path"]] = lora["path"] + # Also map the filename for cases where only the filename is used + filename = os.path.basename(lora["original_path"]) + if filename: + path_mapping[filename] = os.path.basename(lora["path"]) + + # Update the workflow + workflow_str = json.dumps(workflow) + for original_path, local_path in path_mapping.items(): + workflow_str = workflow_str.replace(f'"{original_path}"', f'"{local_path}"') + workflow = json.loads(workflow_str) + + logger.info("Successfully updated workflow with local lora paths") + except Exception as e: + logger.error(f"Error updating workflow with local lora paths: {str(e)}") + import traceback + logger.error(traceback.format_exc()) + # Continue with the original workflow # Queue the workflow try: @@ -226,6 +375,12 @@ def handler(job): except Exception as e: error = f"Error queuing workflow: {str(e)}" update_redis(job_id, "FAILED", error=error) + + # Clean up lora files if any were downloaded + if lora_file_paths: + cleanup_result = cleanup_lora_files(lora_file_paths) + logger.info(f"Cleaned up lora files after error: {cleanup_result['message']}") + return {"error": error, "job_id": job_id} # Poll for completion @@ -241,14 +396,37 @@ def handler(job): else: error = "Max retries reached while waiting for image generation" update_redis(job_id, "FAILED", error=error) + + # Clean up lora files if any were downloaded + if lora_file_paths: + cleanup_result = cleanup_lora_files(lora_file_paths) + logger.info(f"Cleaned up lora files after timeout: {cleanup_result['message']}") + return {"error": 
error, "job_id": job_id} except Exception as e: error = f"Error polling for image generation: {str(e)}" update_redis(job_id, "FAILED", error=error) + + # Clean up lora files if any were downloaded + if lora_file_paths: + cleanup_result = cleanup_lora_files(lora_file_paths) + logger.info(f"Cleaned up lora files after polling error: {cleanup_result['message']}") + return {"error": error, "job_id": job_id} # Process output images images_result = process_output_images(history[prompt_id].get("outputs"), job_id) + + # Clean up lora files if any were downloaded + cleanup_info = {} + if lora_file_paths: + cleanup_result = cleanup_lora_files(lora_file_paths) + logger.info(f"Cleaned up lora files after successful inference: {cleanup_result['message']}") + cleanup_info = { + "lora_cleanup": cleanup_result["status"], + "lora_cleanup_details": cleanup_result.get("details", []) + } + if images_result["status"] == "success": update_redis(job_id, "COMPLETED", result=images_result["message"]) else: @@ -258,7 +436,8 @@ def handler(job): "job_id": job_id, "status": images_result["status"], "message": images_result["message"], - "refresh_worker": REFRESH_WORKER + "refresh_worker": REFRESH_WORKER, + **cleanup_info } def preload_weights(checkpoint_name): @@ -270,49 +449,133 @@ def preload_weights(checkpoint_name): # Load the model logger.info(f"Attempting to preload: {checkpoint_name}") - # Use the exact path to the model in the RunPod volume - model_path = os.path.join('runpod-volume/models/checkpoints', checkpoint_name) - - if not os.path.exists(model_path): - logger.error(f"Model file not found at: {model_path}") - return None - - logger.info(f"Found model at: {model_path}") - - # Get the embedding directory from ComfyUI's folder_paths - embedding_directory = None - try: - embedding_directory = folder_paths.get_folder_paths("embeddings") - logger.info(f"Embedding directory: {embedding_directory}") - except Exception as e: - logger.warning(f"Could not get embedding directory: {str(e)}") - - # Call the function with the correct arguments as used in ComfyUI - logger.info(f"Calling load_checkpoint_guess_config with path: {model_path}") - result = load_checkpoint_guess_config( - model_path, - output_vae=True, - output_clip=True, - embedding_directory=embedding_directory - ) - - # Log the type of result to help with debugging - logger.info(f"Result type: {type(result)}") - - if isinstance(result, tuple): - logger.info(f"Result has {len(result)} elements") - # If it's a tuple, it might be (model, clip, vae) or (model, clip, vae, ...) 
- if len(result) >= 3: - model, clip, vae = result[:3] - logger.info(f"Successfully preloaded: {checkpoint_name}") - return model, clip, vae - else: - logger.error(f"Not enough values returned: expected at least 3, got {len(result)}") + # Special case for FLUX models which have a different structure + if checkpoint_name == "flux1-dev.safetensors": + logger.info("FLUX model detected, using specialized loading approach") + + # Load UNet + unet_path = os.path.join('runpod-volume/models/unet', checkpoint_name) + if not os.path.exists(unet_path): + logger.error(f"FLUX UNet model not found at: {unet_path}") + return None + logger.info(f"Found FLUX UNet model at: {unet_path}") + + # Load VAE + vae_path = os.path.join('runpod-volume/models/vae', 'ae.safetensors') + if not os.path.exists(vae_path): + logger.error(f"FLUX VAE model not found at: {vae_path}") return None + logger.info(f"Found FLUX VAE model at: {vae_path}") + + # Load CLIP models + clip_l_path = os.path.join('runpod-volume/models/clip', 'clip_l.safetensors') + t5_path = os.path.join('runpod-volume/models/clip', 't5xxl_fp8_e4m3fn.safetensors') + if not os.path.exists(clip_l_path): + logger.error(f"FLUX CLIP_L model not found at: {clip_l_path}") + return None + if not os.path.exists(t5_path): + logger.error(f"FLUX T5 model not found at: {t5_path}") + return None + logger.info(f"Found FLUX CLIP models at: {clip_l_path} and {t5_path}") + + # Try to load the components using load_checkpoint_guess_config + unet_model = None + vae_model = None + clip_model = None + + try: + # Load UNet + logger.info(f"Attempting to load FLUX UNet with load_checkpoint_guess_config: {unet_path}") + unet_model = load_checkpoint_guess_config( + unet_path, + output_vae=False, + output_clip=False + ) + logger.info(f"Successfully loaded FLUX UNet: {type(unet_model)}") + except Exception as e: + logger.error(f"Error loading FLUX UNet: {str(e)}") + + try: + # Load VAE + logger.info(f"Attempting to load FLUX VAE: {vae_path}") + vae_model = load_checkpoint_guess_config( + vae_path, + output_vae=True, + output_clip=False + ) + logger.info(f"Successfully loaded FLUX VAE: {type(vae_model)}") + except Exception as e: + logger.error(f"Error loading FLUX VAE: {str(e)}") + + try: + # Load CLIP models + logger.info(f"Attempting to load FLUX CLIP models") + # Import CLIP loading function + from comfy.sd import load_clip + + # Try loading CLIP_L + logger.info(f"Loading CLIP_L model: {clip_l_path}") + clip_model = load_clip(clip_l_path, None) + logger.info(f"Successfully loaded CLIP_L: {type(clip_model)}") + + # Note: T5 model might require special handling and may not load with standard functions + # ComfyUI will handle loading it when needed + logger.info(f"T5 model will be loaded by ComfyUI when needed") + except Exception as e: + logger.error(f"Error loading CLIP models: {str(e)}") + + # Check if we loaded at least some components + if unet_model is not None or vae_model is not None or clip_model is not None: + logger.info(f"Successfully preloaded some FLUX model components") + return unet_model, clip_model, vae_model + else: + logger.warning("Failed to preload any FLUX model components") + logger.info("FLUX models will be loaded by ComfyUI when needed") + return "FLUX_MODEL_PLACEHOLDER" else: - # If it's not a tuple, it might be a single model object - logger.info(f"Result is not a tuple, treating as single model object") - return result, None, None + # Use the exact path to the model in the RunPod volume + model_path = os.path.join('runpod-volume/models/checkpoints', 
checkpoint_name) + + if not os.path.exists(model_path): + logger.error(f"Model file not found at: {model_path}") + return None + + logger.info(f"Found model at: {model_path}") + + # Get the embedding directory from ComfyUI's folder_paths + embedding_directory = None + try: + embedding_directory = folder_paths.get_folder_paths("embeddings") + logger.info(f"Embedding directory: {embedding_directory}") + except Exception as e: + logger.warning(f"Could not get embedding directory: {str(e)}") + + # Call the function with the correct arguments as used in ComfyUI + logger.info(f"Calling load_checkpoint_guess_config with path: {model_path}") + result = load_checkpoint_guess_config( + model_path, + output_vae=True, + output_clip=True, + embedding_directory=embedding_directory + ) + + # Log the type of result to help with debugging + logger.info(f"Result type: {type(result)}") + + if isinstance(result, tuple): + logger.info(f"Result has {len(result)} elements") + # If it's a tuple, it might be (model, clip, vae) or (model, clip, vae, ...) + if len(result) >= 3: + model, clip, vae = result[:3] + logger.info(f"Successfully preloaded: {checkpoint_name}") + return model, clip, vae + else: + logger.error(f"Not enough values returned: expected at least 3, got {len(result)}") + return None + else: + # If it's not a tuple, it might be a single model object + logger.info(f"Result is not a tuple, treating as single model object") + return result, None, None except Exception as e: logger.error(f"Error preloading model {checkpoint_name}: {str(e)}") @@ -338,11 +601,62 @@ def init(): else: logger.warning(f"Model directory does not exist: {model_dir}") + # Check if FLUX is enabled, also check the required directories for FLUX models + if FLUX: + # Check for UNet model + unet_dir = 'runpod-volume/models/unet' + if os.path.exists(unet_dir): + logger.info(f"Unet directory exists: {unet_dir}") + try: + unet_models = os.listdir(unet_dir) + logger.info(f"Available unet models: {unet_models}") + if "flux1-dev.safetensors" not in unet_models: + logger.warning("flux1-dev.safetensors not found in unet directory") + except Exception as e: + logger.error(f"Error listing unet models: {str(e)}") + else: + logger.warning(f"Unet directory does not exist: {unet_dir}") + + # Check for CLIP models + clip_dir = 'runpod-volume/models/clip' + if os.path.exists(clip_dir): + logger.info(f"CLIP directory exists: {clip_dir}") + try: + clip_models = os.listdir(clip_dir) + logger.info(f"Available CLIP models: {clip_models}") + required_clip_models = ["clip_l.safetensors", "t5xxl_fp8_e4m3fn.safetensors"] + for model in required_clip_models: + if model not in clip_models: + logger.warning(f"{model} not found in CLIP directory") + except Exception as e: + logger.error(f"Error listing CLIP models: {str(e)}") + else: + logger.warning(f"CLIP directory does not exist: {clip_dir}") + + # Check for VAE model + vae_dir = 'runpod-volume/models/vae' + if os.path.exists(vae_dir): + logger.info(f"VAE directory exists: {vae_dir}") + try: + vae_models = os.listdir(vae_dir) + logger.info(f"Available VAE models: {vae_models}") + if "ae.safetensors" not in vae_models: + logger.warning("ae.safetensors not found in VAE directory") + except Exception as e: + logger.error(f"Error listing VAE models: {str(e)}") + else: + logger.warning(f"VAE directory does not exist: {vae_dir}") + # Models to preload models_to_preload = [ "realvisxlV40_v40LightningBakedvae.safetensors", "pixelArtDiffusionXL_pixelWorld.safetensors" ] + + # Check if FLUX environment variable is set 
to true + if FLUX: + models_to_preload.append("flux1-dev.safetensors") + logger.info("FLUX environment variable is set to true, adding flux1-dev.safetensors to preload list") # Preload models preloaded = [] @@ -357,6 +671,148 @@ def init(): logger.info(f"Preloaded {len(preloaded)} models: {preloaded}") +def download_lora_files(loras): + """Download lora files from network paths to disk. + + Args: + loras: List of dictionaries with 'path' and 'scale' keys + + Returns: + dict: Status of download operation with downloaded file paths and lora info + """ + if not loras: + return {"status": "success", "message": "No loras to process", "details": [], "file_paths": [], "lora_info": []} + + LORA_DIR = "/runpod-volume/models/loras" + os.makedirs(LORA_DIR, exist_ok=True) + + responses = [] + download_errors = [] + file_paths = [] + lora_info = [] # Store info about all loras (both local and downloaded) + + logger.info(f"Processing {len(loras)} loras") + + for lora in loras: + path = lora["path"] + scale = lora["scale"] + + # Check if the lora is a network path (http or https) + if path.startswith(("http://", "https://")): + # Extract filename from URL + filename = os.path.basename(path) + if not filename: + # Generate a random filename if URL doesn't have one + filename = f"lora_{uuid.uuid4().hex}.safetensors" + elif not filename.endswith('.safetensors'): + filename = f"{filename}.safetensors" + + local_path = os.path.join(LORA_DIR, filename) + file_paths.append(local_path) + + try: + # Download the file + logger.info(f"Downloading lora from {path} to {local_path}") + response = requests.get(path, stream=True) + response.raise_for_status() + + with open(local_path, 'wb') as f: + for chunk in response.iter_content(chunk_size=8192): + f.write(chunk) + + responses.append(f"Successfully downloaded lora from {path}") + logger.info(f"Successfully downloaded lora from {path}") + + # Add to lora info with local path + lora_info.append({ + "path": local_path, + "scale": scale, + "original_path": path, + "downloaded": True, + "lora_name": filename # Add lora_name for ComfyUI workflow + }) + except Exception as e: + error_msg = f"Error downloading lora from {path}: {str(e)}" + download_errors.append(error_msg) + logger.error(error_msg) + + # Still add to lora info but with original path + lora_info.append({ + "path": path, + "scale": scale, + "download_error": str(e), + "downloaded": False + }) + else: + # Local path, no need to download + filename = os.path.basename(path) + logger.info(f"Using local lora at {path}") + lora_info.append({ + "path": path, + "scale": scale, + "downloaded": False, + "lora_name": filename # Add lora_name for ComfyUI workflow + }) + + if download_errors: + logger.error("Lora download completed with errors") + return { + "status": "error", + "message": "Some loras failed to download", + "details": download_errors, + "file_paths": file_paths, + "lora_info": lora_info + } + + logger.info("Lora processing completed successfully") + return { + "status": "success", + "message": "All network loras downloaded successfully", + "details": responses, + "file_paths": file_paths, + "lora_info": lora_info + } + +def cleanup_lora_files(file_paths): + """Delete downloaded lora files after inference is complete. 
+ + Args: + file_paths: List of file paths to delete + + Returns: + dict: Status of cleanup operation + """ + if not file_paths: + return {"status": "success", "message": "No lora files to clean up"} + + cleanup_errors = [] + deleted_files = [] + + for file_path in file_paths: + try: + if os.path.exists(file_path): + os.remove(file_path) + deleted_files.append(file_path) + logger.info(f"Successfully deleted lora file: {file_path}") + except Exception as e: + error_msg = f"Error deleting {file_path}: {str(e)}" + cleanup_errors.append(error_msg) + logger.error(error_msg) + + if cleanup_errors: + return { + "status": "warning", + "message": "Some lora files could not be deleted", + "details": cleanup_errors, + "deleted_files": deleted_files + } + + return { + "status": "success", + "message": "All lora files cleaned up successfully", + "details": deleted_files + } + # Only initialize if this module is the main program if __name__ == "__main__": # Initialize after a delay to ensure ComfyUI is running From d7e439b0f582b1a47609b2f2bfad8d4f9f3c90a8 Mon Sep 17 00:00:00 2001 From: Michael Aubry Date: Wed, 5 Mar 2025 16:27:04 +0100 Subject: [PATCH 5/8] chore: add webhooks --- src/rp_handler.py | 201 ++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 195 insertions(+), 6 deletions(-) diff --git a/src/rp_handler.py b/src/rp_handler.py index acb3403e..ae9ecb85 100644 --- a/src/rp_handler.py +++ b/src/rp_handler.py @@ -93,8 +93,13 @@ def validate_input(job_input): "path" in lora and "scale" in lora for lora in loras ): return None, "'loras' must be a list of objects with 'path' and 'scale' keys" + + # Get optional webhook URL + webhook_url = job_input.get("webhookUrl") + if webhook_url is not None and not isinstance(webhook_url, str): + return None, "'webhookUrl' must be a string" - return {"workflow": workflow, "images": images, "loras": loras}, None + return {"workflow": workflow, "images": images, "loras": loras, "webhook_url": webhook_url}, None def check_server(url, retries=500, delay=50): """Check if a server is reachable via HTTP GET request.""" @@ -258,6 +263,7 @@ def handler(job): job_id = job.get("id", str(uuid.uuid4())) # Use provided job ID or generate one job_input = job["input"] lora_file_paths = [] # Track downloaded lora files for cleanup + start_time = datetime.utcnow().isoformat() # Validate input validated_data, error_message = validate_input(job_input) @@ -268,6 +274,7 @@ def handler(job): workflow = validated_data["workflow"] images = validated_data.get("images") loras = validated_data.get("loras") + webhook_url = validated_data.get("webhook_url") # Write initial state to Redis update_redis(job_id, "NOT_STARTED", workflow=workflow) @@ -276,20 +283,43 @@ def handler(job): if not check_server(f"http://{COMFY_HOST}", COMFY_API_AVAILABLE_MAX_RETRIES, COMFY_API_AVAILABLE_INTERVAL_MS): error = "ComfyUI API unavailable" update_redis(job_id, "FAILED", error=error) + + # Call webhook with error if provided + if webhook_url: + call_webhook(webhook_url, job_id, "error", error=error, additional_data={ + "start_time": start_time, + "end_time": datetime.utcnow().isoformat(), + "error_type": "comfy_unavailable" + }) + return {"error": error, "job_id": job_id} # Upload images if provided upload_result = upload_images(images) if upload_result["status"] == "error": update_redis(job_id, "FAILED", error=upload_result["message"]) + + # Call webhook with error if provided + if webhook_url: + call_webhook(webhook_url, job_id, "error", error=upload_result["message"], additional_data={ + 
"start_time": start_time, + "end_time": datetime.utcnow().isoformat(), + "error_type": "image_upload_failed", + "upload_details": upload_result.get("details", []) + }) + return {**upload_result, "job_id": job_id} # Process loras if provided + lora_info = [] if loras: logger.info(f"Processing {len(loras)} loras") update_redis(job_id, "PROCESSING_LORAS") lora_result = download_lora_files(loras) + + # Only track paths of downloaded loras, not local ones lora_file_paths = lora_result.get("file_paths", []) + logger.info(f"Tracking {len(lora_file_paths)} downloaded lora files for cleanup: {lora_file_paths}") lora_info = lora_result.get("lora_info", []) if lora_result["status"] == "error": @@ -380,6 +410,16 @@ def handler(job): if lora_file_paths: cleanup_result = cleanup_lora_files(lora_file_paths) logger.info(f"Cleaned up lora files after error: {cleanup_result['message']}") + logger.info(f"Deleted {len(cleanup_result.get('details', []))} files, skipped {len(cleanup_result.get('skipped_files', []))} files") + + # Call webhook with error if provided + if webhook_url: + call_webhook(webhook_url, job_id, "error", error=error, additional_data={ + "start_time": start_time, + "end_time": datetime.utcnow().isoformat(), + "error_type": "workflow_queue_failed", + "lora_count": len(lora_info) if lora_info else 0 + }) return {"error": error, "job_id": job_id} @@ -401,6 +441,17 @@ def handler(job): if lora_file_paths: cleanup_result = cleanup_lora_files(lora_file_paths) logger.info(f"Cleaned up lora files after timeout: {cleanup_result['message']}") + logger.info(f"Deleted {len(cleanup_result.get('details', []))} files, skipped {len(cleanup_result.get('skipped_files', []))} files") + + # Call webhook with error if provided + if webhook_url: + call_webhook(webhook_url, job_id, "error", error=error, additional_data={ + "start_time": start_time, + "end_time": datetime.utcnow().isoformat(), + "error_type": "generation_timeout", + "retries": retries, + "lora_count": len(lora_info) if lora_info else 0 + }) return {"error": error, "job_id": job_id} except Exception as e: @@ -411,6 +462,16 @@ def handler(job): if lora_file_paths: cleanup_result = cleanup_lora_files(lora_file_paths) logger.info(f"Cleaned up lora files after polling error: {cleanup_result['message']}") + logger.info(f"Deleted {len(cleanup_result.get('details', []))} files, skipped {len(cleanup_result.get('skipped_files', []))} files") + + # Call webhook with error if provided + if webhook_url: + call_webhook(webhook_url, job_id, "error", error=error, additional_data={ + "start_time": start_time, + "end_time": datetime.utcnow().isoformat(), + "error_type": "polling_error", + "lora_count": len(lora_info) if lora_info else 0 + }) return {"error": error, "job_id": job_id} @@ -422,15 +483,57 @@ def handler(job): if lora_file_paths: cleanup_result = cleanup_lora_files(lora_file_paths) logger.info(f"Cleaned up lora files after successful inference: {cleanup_result['message']}") + logger.info(f"Deleted {len(cleanup_result.get('details', []))} files, skipped {len(cleanup_result.get('skipped_files', []))} files") cleanup_info = { "lora_cleanup": cleanup_result["status"], - "lora_cleanup_details": cleanup_result.get("details", []) + "lora_cleanup_details": cleanup_result.get("details", []), + "lora_skipped_files": cleanup_result.get("skipped_files", []) } + end_time = datetime.utcnow().isoformat() + if images_result["status"] == "success": update_redis(job_id, "COMPLETED", result=images_result["message"]) + + # Call webhook with success result if provided + if 
webhook_url: + # Prepare additional data for webhook + additional_data = { + "start_time": start_time, + "end_time": end_time, + "processing_time_seconds": (datetime.fromisoformat(end_time) - datetime.fromisoformat(start_time)).total_seconds(), + "lora_count": len(lora_info) if lora_info else 0, + "image_count": len(images_result.get("message", [])), + "prompt_id": prompt_id + } + + webhook_result = call_webhook( + webhook_url, + job_id, + "success", + result=images_result["message"], + additional_data=additional_data + ) + cleanup_info["webhook_result"] = webhook_result else: update_redis(job_id, "FAILED", error="Image processing failed") + + # Call webhook with error if provided + if webhook_url: + webhook_result = call_webhook( + webhook_url, + job_id, + "error", + error="Image processing failed", + additional_data={ + "start_time": start_time, + "end_time": end_time, + "error_type": "image_processing_failed", + "lora_count": len(lora_info) if lora_info else 0, + "prompt_id": prompt_id + } + ) + cleanup_info["webhook_result"] = webhook_result return { "job_id": job_id, @@ -709,6 +812,7 @@ def download_lora_files(loras): local_path = os.path.join(LORA_DIR, filename) file_paths.append(local_path) + logger.info(f"Added downloaded lora path to file_paths: {local_path}") try: # Download the file @@ -746,7 +850,7 @@ def download_lora_files(loras): else: # Local path, no need to download filename = os.path.basename(path) - logger.info(f"Using local lora at {path}") + logger.info(f"Using local lora at {path} (not adding to file_paths for cleanup)") lora_info.append({ "path": path, "scale": scale, @@ -785,15 +889,29 @@ def cleanup_lora_files(file_paths): if not file_paths: return {"status": "success", "message": "No lora files to clean up"} + logger.info(f"Cleanup requested for {len(file_paths)} lora files: {file_paths}") + cleanup_errors = [] deleted_files = [] + skipped_files = [] + + # Define the directory where downloaded loras are stored + LORA_DIR = "/runpod-volume/models/loras" for file_path in file_paths: try: - if os.path.exists(file_path): + # Only delete files that are in the download directory + # This ensures we only delete files that were downloaded from URLs + if os.path.exists(file_path) and file_path.startswith(LORA_DIR): + logger.info(f"Deleting lora file: {file_path}") os.remove(file_path) deleted_files.append(file_path) logger.info(f"Successfully deleted lora file: {file_path}") + else: + # Skip files that don't exist or aren't in the download directory + reason = "non-existent file" if not os.path.exists(file_path) else "local file (not in download directory)" + logger.warning(f"Skipping deletion of {reason}: {file_path}") + skipped_files.append(file_path) except Exception as e: error_msg = f"Error deleting {file_path}: {str(e)}" cleanup_errors.append(error_msg) @@ -804,15 +922,86 @@ def cleanup_lora_files(file_paths): "status": "warning", "message": "Some lora files could not be deleted", "details": cleanup_errors, - "deleted_files": deleted_files + "deleted_files": deleted_files, + "skipped_files": skipped_files } return { "status": "success", "message": "All lora files cleaned up successfully", - "details": deleted_files + "details": deleted_files, + "skipped_files": skipped_files } +def call_webhook(webhook_url, job_id, status, result=None, error=None, additional_data=None): + """Call a webhook URL with the job results. 
+ + Args: + webhook_url: URL to call + job_id: ID of the job + status: Status of the job (success, error) + result: Result of the job (if successful) + error: Error message (if failed) + additional_data: Additional data to include in the payload + + Returns: + dict: Status of webhook call + """ + if not webhook_url: + return {"status": "skipped", "message": "No webhook URL provided"} + + try: + logger.info(f"Calling webhook URL: {webhook_url}") + + # Prepare payload + payload = { + "job_id": job_id, + "status": status, + "timestamp": datetime.utcnow().isoformat() + } + + if result is not None: + payload["result"] = result + + if error is not None: + payload["error"] = error + + if additional_data is not None: + payload.update(additional_data) + + # Send POST request to webhook URL + response = requests.post( + webhook_url, + json=payload, + headers={"Content-Type": "application/json"}, + timeout=30 # 30 second timeout + ) + + # Check response + if response.status_code >= 200 and response.status_code < 300: + logger.info(f"Webhook call successful: {response.status_code}") + return { + "status": "success", + "message": f"Webhook call successful: {response.status_code}", + "response": response.text + } + else: + logger.error(f"Webhook call failed: {response.status_code} - {response.text}") + return { + "status": "error", + "message": f"Webhook call failed: {response.status_code}", + "response": response.text + } + except Exception as e: + error_msg = f"Error calling webhook: {str(e)}" + logger.error(error_msg) + import traceback + logger.error(traceback.format_exc()) + return { + "status": "error", + "message": error_msg + } + # Only initialize if this module is the main program if __name__ == "__main__": # Initialize after a delay to ensure ComfyUI is running From 1c9c8f4922b3b110a6862c2ebd3d2b67389f8bf1 Mon Sep 17 00:00:00 2001 From: Michael Aubry Date: Thu, 6 Mar 2025 07:56:14 +0100 Subject: [PATCH 6/8] chore: add logs and better deletion logic --- src/rp_handler.py | 368 ++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 308 insertions(+), 60 deletions(-) diff --git a/src/rp_handler.py b/src/rp_handler.py index ae9ecb85..bf9378ed 100644 --- a/src/rp_handler.py +++ b/src/rp_handler.py @@ -9,15 +9,43 @@ import base64 from io import BytesIO import redis -from datetime import datetime +from datetime import datetime, timedelta import logging import uuid import sys +import traceback +from typing import Dict, List, Tuple, Any, Optional, Union # Configure logging -logging.basicConfig(level=logging.INFO) +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s [%(name)s] [%(levelname)s] [%(job_id)s] %(message)s', + datefmt='%Y-%m-%d %H:%M:%S' +) logger = logging.getLogger(__name__) +# Add a filter to include job_id in log records +class JobIdFilter(logging.Filter): + """Filter that adds job_id to log records.""" + + def __init__(self, name=''): + super().__init__(name) + self.job_id = 'no_job' + + def set_job_id(self, job_id): + """Set the current job ID.""" + self.job_id = job_id + + def filter(self, record): + """Add job_id to the log record.""" + if not hasattr(record, 'job_id'): + record.job_id = self.job_id + return True + +# Create and add the filter to the logger +job_id_filter = JobIdFilter() +logger.addFilter(job_id_filter) + # Add ComfyUI directory to Python path comfy_path = "/comfyui" if comfy_path not in sys.path: @@ -50,6 +78,9 @@ def update_redis(job_id, state, workflow=None, result=None, error=None): """Update job state in Redis.""" + # Set job ID for logging
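+ # job_id_filter is the module-level JobIdFilter attached to the logger
+ # above; re-pointing it at the current job tags every subsequent log line
+ # from this module with the ID through the %(job_id)s field of the
+ # logging format. Records from other loggers never pass this filter
+ # and stay untagged.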
+ job_id_filter.set_job_id(job_id) + job = { "id": job_id, "state": state, @@ -60,9 +91,9 @@ def update_redis(job_id, state, workflow=None, result=None, error=None): } try: r.set(f"job:{job_id}", json.dumps(job)) - logger.info(f"Updated job {job_id} state to {state}") + logger.info(f"Updated job state to {state}") except Exception as e: - logger.error(f"Failed to update Redis for job {job_id}: {str(e)}") + logger.error(f"Failed to update Redis: {str(e)}") def validate_input(job_input): """Validates the input for the handler function.""" @@ -101,22 +132,49 @@ def validate_input(job_input): return {"workflow": workflow, "images": images, "loras": loras, "webhook_url": webhook_url}, None -def check_server(url, retries=500, delay=50): - """Check if a server is reachable via HTTP GET request.""" - for i in range(retries): +def check_server(url, retries=500, delay=50, job_id=None): + """Check if ComfyUI API is available. + + Args: + url: URL to check + retries: Number of retries + delay: Delay between retries in milliseconds + job_id: ID of the job for logging + + Returns: + bool: True if API is available, False otherwise + """ + # Set job ID for logging if provided + if job_id: + job_id_filter.set_job_id(job_id) + + for _ in range(retries): try: - response = requests.get(url) + response = requests.get(f"{url}/system_stats", timeout=5) if response.status_code == 200: logger.info("ComfyUI API is reachable") return True - except requests.RequestException: + except: pass time.sleep(delay / 1000) + logger.error(f"Failed to connect to {url} after {retries} attempts") return False -def upload_images(images): - """Upload a list of base64 encoded images to the ComfyUI server.""" +def upload_images(images, job_id=None): + """Upload images to ComfyUI. + + Args: + images: List of image data (base64 or URLs) + job_id: ID of the job for logging + + Returns: + dict: Status of upload operation + """ + # Set job ID for logging if provided + if job_id: + job_id_filter.set_job_id(job_id) + if not images: return {"status": "success", "message": "No images to upload", "details": []} @@ -152,16 +210,41 @@ def upload_images(images): "details": responses, } -def queue_workflow(workflow): - """Queue a workflow to be processed by ComfyUI.""" - data = json.dumps({"prompt": workflow}).encode("utf-8") - req = urllib.request.Request(f"http://{COMFY_HOST}/prompt", data=data) - return json.loads(urllib.request.urlopen(req).read()) +def queue_workflow(workflow, job_id=None): + """Queue a workflow in ComfyUI. + + Args: + workflow: Workflow to queue + job_id: ID of the job for logging + + Returns: + dict: Response from ComfyUI API + """ + # Set job ID for logging if provided + if job_id: + job_id_filter.set_job_id(job_id) + + response = requests.post(f"http://{COMFY_HOST}/prompt", json=workflow) + response.raise_for_status() + return response.json() -def get_history(prompt_id): - """Retrieve the history of a given prompt using its ID.""" - with urllib.request.urlopen(f"http://{COMFY_HOST}/history/{prompt_id}") as response: - return json.loads(response.read()) +def get_history(prompt_id, job_id=None): + """Get history for a prompt from ComfyUI. 
+ + Args: + prompt_id: ID of the prompt + job_id: ID of the job for logging + + Returns: + dict: History from ComfyUI API + """ + # Set job ID for logging if provided + if job_id: + job_id_filter.set_job_id(job_id) + + response = requests.get(f"http://{COMFY_HOST}/history/{prompt_id}") + response.raise_for_status() + return response.json() def base64_encode(img_path): """Returns base64 encoded image.""" @@ -169,7 +252,21 @@ def base64_encode(img_path): return base64.b64encode(image_file.read()).decode("utf-8") def process_output_images(outputs, job_id): - """Process generated images and return as S3 URL or base64.""" + """Process output images from ComfyUI. + + Args: + outputs: Dictionary of outputs from ComfyUI + job_id: ID of the job + + Returns: + dict: Status of processing with image URLs + """ + # Set job ID for logging + job_id_filter.set_job_id(job_id) + + if not outputs: + return {"status": "error", "message": "No outputs found"} + COMFY_OUTPUT_PATH = os.environ.get("COMFY_OUTPUT_PATH", "/comfyui/output") output_images = [] @@ -202,15 +299,23 @@ def process_output_images(outputs, job_id): "message": processed_images, } -def verify_lora_files(lora_info): - """Verify that downloaded lora files are accessible to ComfyUI. +def verify_lora_files(lora_info, job_id=None): + """Verify that lora files are accessible to ComfyUI. Args: lora_info: List of dictionaries with lora information + job_id: ID of the job for logging Returns: - bool: True if all files are accessible, False otherwise + bool: True if all lora files are accessible, False otherwise """ + # Set job ID for logging if provided + if job_id: + job_id_filter.set_job_id(job_id) + + if not lora_info: + return True + try: # Import ComfyUI's folder_paths module to check lora paths import folder_paths @@ -219,19 +324,35 @@ def verify_lora_files(lora_info): lora_folders = folder_paths.get_folder_paths("loras") logger.info(f"ComfyUI lora folders: {lora_folders}") - # Check if our lora directory is in the ComfyUI lora folders - our_lora_dir = "/runpod-volume/models/loras" - if our_lora_dir not in lora_folders: - logger.warning(f"Our lora directory {our_lora_dir} is not in ComfyUI's lora folders") - # Try to add it to ComfyUI's folder paths - folder_paths.add_model_folder_path("loras", our_lora_dir) - logger.info(f"Added {our_lora_dir} to ComfyUI's lora folders") - - # Verify it was added - lora_folders = folder_paths.get_folder_paths("loras") + # Check which lora directory exists and use that + LORA_DIRS = [ + "/runpod-volume/models/loras", + "/workspace/models/loras" + ] + + our_lora_dirs = [] + for dir_path in LORA_DIRS: + if os.path.exists(dir_path): + our_lora_dirs.append(dir_path) + logger.info(f"Found lora directory: {dir_path}") + + if not our_lora_dirs: + logger.warning("No lora directories found") + return False + + # Check if our lora directories are in ComfyUI's lora folders + for our_lora_dir in our_lora_dirs: if our_lora_dir not in lora_folders: - logger.error(f"Failed to add {our_lora_dir} to ComfyUI's lora folders") - return False + logger.warning(f"Our lora directory {our_lora_dir} is not in ComfyUI's lora folders") + # Try to add it to ComfyUI's folder paths + folder_paths.add_model_folder_path("loras", our_lora_dir) + logger.info(f"Added {our_lora_dir} to ComfyUI's lora folders") + + # Verify it was added + lora_folders = folder_paths.get_folder_paths("loras") + if our_lora_dir not in lora_folders: + logger.error(f"Failed to add {our_lora_dir} to ComfyUI's lora folders") + return False # Check if each downloaded lora 
file exists and is accessible for lora in lora_info: @@ -239,7 +360,21 @@ def verify_lora_files(lora_info): lora_path = lora.get("path") if not os.path.exists(lora_path): logger.error(f"Downloaded lora file does not exist: {lora_path}") - return False + + # Check if the file exists in the other lora directory + filename = os.path.basename(lora_path) + found = False + for our_lora_dir in our_lora_dirs: + if our_lora_dir not in lora_path: # Check if it's a different directory + alt_path = os.path.join(our_lora_dir, filename) + if os.path.exists(alt_path): + logger.info(f"Found lora file in alternate location: {alt_path}") + lora["path"] = alt_path # Update the path + found = True + break + + if not found: + return False # Check if the file is readable try: @@ -261,6 +396,10 @@ def verify_lora_files(lora_info): def handler(job): """Main handler for processing a job.""" job_id = job.get("id", str(uuid.uuid4())) # Use provided job ID or generate one + + # Set job ID for logging + job_id_filter.set_job_id(job_id) + job_input = job["input"] lora_file_paths = [] # Track downloaded lora files for cleanup start_time = datetime.utcnow().isoformat() @@ -280,7 +419,7 @@ def handler(job): update_redis(job_id, "NOT_STARTED", workflow=workflow) # Check ComfyUI availability - if not check_server(f"http://{COMFY_HOST}", COMFY_API_AVAILABLE_MAX_RETRIES, COMFY_API_AVAILABLE_INTERVAL_MS): + if not check_server(f"http://{COMFY_HOST}", COMFY_API_AVAILABLE_MAX_RETRIES, COMFY_API_AVAILABLE_INTERVAL_MS, job_id): error = "ComfyUI API unavailable" update_redis(job_id, "FAILED", error=error) @@ -295,7 +434,7 @@ def handler(job): return {"error": error, "job_id": job_id} # Upload images if provided - upload_result = upload_images(images) + upload_result = upload_images(images, job_id) if upload_result["status"] == "error": update_redis(job_id, "FAILED", error=upload_result["message"]) @@ -315,7 +454,7 @@ def handler(job): if loras: logger.info(f"Processing {len(loras)} loras") update_redis(job_id, "PROCESSING_LORAS") - lora_result = download_lora_files(loras) + lora_result = download_lora_files(loras, job_id) # Only track paths of downloaded loras, not local ones lora_file_paths = lora_result.get("file_paths", []) @@ -332,7 +471,7 @@ def handler(job): # Verify lora files are accessible to ComfyUI if lora_info: - if not verify_lora_files(lora_info): + if not verify_lora_files(lora_info, job_id): logger.warning("Some lora files may not be accessible to ComfyUI") update_redis(job_id, "LORA_VERIFICATION_WARNING", error="Some lora files may not be accessible to ComfyUI") else: @@ -399,7 +538,7 @@ def handler(job): # Queue the workflow try: update_redis(job_id, "IN_QUEUE") - queued_workflow = queue_workflow(workflow) + queued_workflow = queue_workflow(workflow, job_id) prompt_id = queued_workflow["prompt_id"] logger.info(f"Queued workflow with prompt ID {prompt_id}") except Exception as e: @@ -408,7 +547,7 @@ def handler(job): # Clean up lora files if any were downloaded if lora_file_paths: - cleanup_result = cleanup_lora_files(lora_file_paths) + cleanup_result = cleanup_lora_files(lora_file_paths, job_id) logger.info(f"Cleaned up lora files after error: {cleanup_result['message']}") logger.info(f"Deleted {len(cleanup_result.get('details', []))} files, skipped {len(cleanup_result.get('skipped_files', []))} files") @@ -428,7 +567,7 @@ def handler(job): retries = 0 try: while retries < COMFY_POLLING_MAX_RETRIES: - history = get_history(prompt_id) + history = get_history(prompt_id, job_id) if prompt_id in history and 
history[prompt_id].get("outputs"): break time.sleep(COMFY_POLLING_INTERVAL_MS / 1000) @@ -439,7 +578,7 @@ def handler(job): # Clean up lora files if any were downloaded if lora_file_paths: - cleanup_result = cleanup_lora_files(lora_file_paths) + cleanup_result = cleanup_lora_files(lora_file_paths, job_id) logger.info(f"Cleaned up lora files after timeout: {cleanup_result['message']}") logger.info(f"Deleted {len(cleanup_result.get('details', []))} files, skipped {len(cleanup_result.get('skipped_files', []))} files") @@ -460,7 +599,7 @@ def handler(job): # Clean up lora files if any were downloaded if lora_file_paths: - cleanup_result = cleanup_lora_files(lora_file_paths) + cleanup_result = cleanup_lora_files(lora_file_paths, job_id) logger.info(f"Cleaned up lora files after polling error: {cleanup_result['message']}") logger.info(f"Deleted {len(cleanup_result.get('details', []))} files, skipped {len(cleanup_result.get('skipped_files', []))} files") @@ -481,7 +620,7 @@ def handler(job): # Clean up lora files if any were downloaded cleanup_info = {} if lora_file_paths: - cleanup_result = cleanup_lora_files(lora_file_paths) + cleanup_result = cleanup_lora_files(lora_file_paths, job_id) logger.info(f"Cleaned up lora files after successful inference: {cleanup_result['message']}") logger.info(f"Deleted {len(cleanup_result.get('details', []))} files, skipped {len(cleanup_result.get('skipped_files', []))} files") cleanup_info = { @@ -688,6 +827,9 @@ def preload_weights(checkpoint_name): def init(): """Initialize the handler.""" + # Set a default job ID for initialization logs + job_id_filter.set_job_id('init') + # Wait for ComfyUI to start up logger.info("Waiting for ComfyUI to initialize...") time.sleep(15) # Give ComfyUI time to start @@ -774,19 +916,41 @@ def init(): logger.info(f"Preloaded {len(preloaded)} models: {preloaded}") -def download_lora_files(loras): +def download_lora_files(loras, job_id=None): """Download lora files from network paths to disk. 
Args: loras: List of dictionaries with 'path' and 'scale' keys + job_id: ID of the job for logging Returns: dict: Status of download operation with downloaded file paths and lora info """ + # Set job ID for logging if provided + if job_id: + job_id_filter.set_job_id(job_id) + if not loras: return {"status": "success", "message": "No loras to process", "details": [], "file_paths": [], "lora_info": []} - LORA_DIR = "/runpod-volume/models/loras" + # Check which lora directory exists and use that + LORA_DIRS = [ + "/runpod-volume/models/loras", + "/workspace/models/loras" + ] + + LORA_DIR = None + for dir_path in LORA_DIRS: + if os.path.exists(dir_path): + LORA_DIR = dir_path + logger.info(f"Using lora directory: {LORA_DIR}") + break + + if LORA_DIR is None: + # If neither exists, default to the first one and create it + LORA_DIR = LORA_DIRS[0] + logger.warning(f"No existing lora directory found, creating: {LORA_DIR}") + os.makedirs(LORA_DIR, exist_ok=True) responses = [] @@ -811,8 +975,43 @@ def download_lora_files(loras): filename = f"{filename}.safetensors" local_path = os.path.join(LORA_DIR, filename) - file_paths.append(local_path) - logger.info(f"Added downloaded lora path to file_paths: {local_path}") + + # Check if the file already exists + file_exists = os.path.exists(local_path) + + # Check if this is a protected model + is_protected = is_protected_model(filename) + + if file_exists: + logger.info(f"Lora file already exists: {local_path}") + + # Only add to file_paths for cleanup if it's not a protected model + # (i.e., only if it has "pytorch_lora_weights" in the name) + if not is_protected: + file_paths.append(local_path) + logger.info(f"Added existing lora path to file_paths for cleanup: {local_path}") + else: + logger.info(f"Not adding protected model to cleanup list: {local_path}") + + # Add to lora info with local path + lora_info.append({ + "path": local_path, + "scale": scale, + "original_path": path, + "downloaded": False, # Mark as not downloaded since it already existed + "lora_name": filename # Add lora_name for ComfyUI workflow + }) + + responses.append(f"Using existing lora file: {local_path}") + continue + + # If we get here, we need to download the file + # Only add to file_paths for cleanup if it's not a protected model + if not is_protected: + file_paths.append(local_path) + logger.info(f"Added downloaded lora path to file_paths for cleanup: {local_path}") + else: + logger.info(f"Not adding protected model to cleanup list: {local_path}") try: # Download the file @@ -877,15 +1076,39 @@ def download_lora_files(loras): "lora_info": lora_info } -def cleanup_lora_files(file_paths): +def is_protected_model(filename): + """Check if a file is a protected model that should never be deleted. + + Args: + filename: Name of the file to check + + Returns: + bool: True if the file is protected, False otherwise + """ + # Check if the file has the signature of a model that should be deleted + # Only delete files with "pytorch_lora_weights" in the name + if "pytorch_lora_weights" in filename: + logger.info(f"File {filename} is not protected (has pytorch_lora_weights in name)") + return False + + # All other models are protected + logger.info(f"File {filename} is protected (does not have pytorch_lora_weights in name)") + return True + +def cleanup_lora_files(file_paths, job_id=None): """Delete downloaded lora files after inference is complete. 
Args: file_paths: List of file paths to delete + job_id: ID of the job for logging Returns: dict: Status of cleanup operation """ + # Set job ID for logging if provided + if job_id: + job_id_filter.set_job_id(job_id) + if not file_paths: return {"status": "success", "message": "No lora files to clean up"} @@ -895,21 +1118,43 @@ def cleanup_lora_files(file_paths, job_id=None): deleted_files = [] skipped_files = [] - # Define the directory where downloaded loras are stored - LORA_DIR = "/runpod-volume/models/loras" + # Define the directories where downloaded loras are stored + LORA_DIRS = [ + "/runpod-volume/models/loras", + "/workspace/models/loras" + ] for file_path in file_paths: try: - # Only delete files that are in the download directory - # This ensures we only delete files that were downloaded from URLs - if os.path.exists(file_path) and file_path.startswith(LORA_DIR): - logger.info(f"Deleting lora file: {file_path}") + # Check if the file exists + if not os.path.exists(file_path): + reason = "non-existent file" + logger.warning(f"Skipping deletion of {reason}: {file_path}") + skipped_files.append(file_path) + continue + + # Check if the file is in one of the lora directories + is_in_lora_dir = any(file_path.startswith(lora_dir) for lora_dir in LORA_DIRS) + + # Get the filename + filename = os.path.basename(file_path) + + # Check if the file is a protected model + if is_protected_model(filename): + reason = "protected model file (does not have pytorch_lora_weights in name)" + logger.warning(f"Skipping deletion of {reason}: {file_path}") + skipped_files.append(file_path) + continue + + # Only delete files that are in the download directories AND were downloaded from URLs + # We can identify URL downloads because they were added to file_paths in download_lora_files + if is_in_lora_dir: + logger.info(f"Deleting lora file (has pytorch_lora_weights in name): {file_path}") os.remove(file_path) deleted_files.append(file_path) logger.info(f"Successfully deleted lora file: {file_path}") else: - reason = "local file (not in download directory)" logger.warning(f"Skipping deletion of {reason}: {file_path}") skipped_files.append(file_path) except Exception as e: @@ -947,6 +1192,9 @@ def call_webhook(webhook_url, job_id, status, result=None, error=None, additiona Returns: dict: Status of webhook call """ + # Set job ID for logging + job_id_filter.set_job_id(job_id) + if not webhook_url: return {"status": "skipped", "message": "No webhook URL provided"} From ab62362a1474a0f932c1d3ac6c89b2b46c5f7264 Mon Sep 17 00:00:00 2001 From: Michael Aubry Date: Thu, 6 Mar 2025 08:21:41 +0100 Subject: [PATCH 7/8] chore: improve the lora deletion logic --- src/rp_handler.py | 376 +++++++++------------------------------------- 1 file changed, 68 insertions(+), 308 deletions(-) diff --git a/src/rp_handler.py b/src/rp_handler.py index bf9378ed..fd669d4e 100644 --- a/src/rp_handler.py +++ b/src/rp_handler.py @@ -9,43 +9,15 @@ import base64 from io import BytesIO import redis -from datetime import datetime, timedelta +from datetime import datetime import logging import uuid import sys -import traceback -from typing import Dict, List, Tuple, Any, Optional, Union # Configure logging -logging.basicConfig( - level=logging.INFO, - format='%(asctime)s [%(name)s] [%(levelname)s] [%(job_id)s] %(message)s', - datefmt='%Y-%m-%d %H:%M:%S' -)
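+# Removing the %(job_id)s format field and the JobIdFilter together is what
+# keeps this revert safe: a format string that references %(job_id)s makes
+# the logging module report "--- Logging error ---" for any record that no
+# filter has decorated, so dropping one without the other would garble log
+# output from every logger that bypasses this module's filter.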
+logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) -# Add a filter to include job_id in log records -class JobIdFilter(logging.Filter): - """Filter that adds job_id to log records.""" - - def __init__(self, name=''): - super().__init__(name) - self.job_id = 'no_job' - - def set_job_id(self, job_id): - """Set the current job ID.""" - self.job_id = job_id - - def filter(self, record): - """Add job_id to the log record.""" - if not hasattr(record, 'job_id'): - record.job_id = self.job_id - return True - -# Create and add the filter to the logger -job_id_filter = JobIdFilter() -logger.addFilter(job_id_filter) - # Add ComfyUI directory to Python path comfy_path = "/comfyui" if comfy_path not in sys.path: @@ -78,9 +50,6 @@ def filter(self, record): def update_redis(job_id, state, workflow=None, result=None, error=None): """Update job state in Redis.""" - # Set job ID for logging - job_id_filter.set_job_id(job_id) - job = { "id": job_id, "state": state, @@ -91,9 +60,9 @@ def update_redis(job_id, state, workflow=None, result=None, error=None): } try: r.set(f"job:{job_id}", json.dumps(job)) - logger.info(f"Updated job state to {state}") + logger.info(f"Updated job {job_id} state to {state}") except Exception as e: - logger.error(f"Failed to update Redis: {str(e)}") + logger.error(f"Failed to update Redis for job {job_id}: {str(e)}") def validate_input(job_input): """Validates the input for the handler function.""" @@ -132,49 +101,22 @@ def validate_input(job_input): return {"workflow": workflow, "images": images, "loras": loras, "webhook_url": webhook_url}, None -def check_server(url, retries=500, delay=50, job_id=None): - """Check if ComfyUI API is available. - - Args: - url: URL to check - retries: Number of retries - delay: Delay between retries in milliseconds - job_id: ID of the job for logging - - Returns: - bool: True if API is available, False otherwise - """ - # Set job ID for logging if provided - if job_id: - job_id_filter.set_job_id(job_id) - - for _ in range(retries): +def check_server(url, retries=500, delay=50): + """Check if a server is reachable via HTTP GET request.""" + for i in range(retries): try: - response = requests.get(f"{url}/system_stats", timeout=5) + response = requests.get(url) if response.status_code == 200: logger.info("ComfyUI API is reachable") return True - except: + except requests.RequestException: pass time.sleep(delay / 1000) - logger.error(f"Failed to connect to {url} after {retries} attempts") return False -def upload_images(images, job_id=None): - """Upload images to ComfyUI. - - Args: - images: List of image data (base64 or URLs) - job_id: ID of the job for logging - - Returns: - dict: Status of upload operation - """ - # Set job ID for logging if provided - if job_id: - job_id_filter.set_job_id(job_id) - +def upload_images(images): + """Upload a list of base64 encoded images to the ComfyUI server.""" if not images: return {"status": "success", "message": "No images to upload", "details": []} @@ -210,41 +152,16 @@ def upload_images(images, job_id=None): "details": responses, } -def queue_workflow(workflow, job_id=None): - """Queue a workflow in ComfyUI. 
- - Args: - workflow: Workflow to queue - job_id: ID of the job for logging - - Returns: - dict: Response from ComfyUI API - """ - # Set job ID for logging if provided - if job_id: - job_id_filter.set_job_id(job_id) - - response = requests.post(f"http://{COMFY_HOST}/prompt", json=workflow) - response.raise_for_status() - return response.json() +def queue_workflow(workflow): + """Queue a workflow to be processed by ComfyUI.""" + data = json.dumps({"prompt": workflow}).encode("utf-8") + req = urllib.request.Request(f"http://{COMFY_HOST}/prompt", data=data) + return json.loads(urllib.request.urlopen(req).read()) -def get_history(prompt_id, job_id=None): - """Get history for a prompt from ComfyUI. - - Args: - prompt_id: ID of the prompt - job_id: ID of the job for logging - - Returns: - dict: History from ComfyUI API - """ - # Set job ID for logging if provided - if job_id: - job_id_filter.set_job_id(job_id) - - response = requests.get(f"http://{COMFY_HOST}/history/{prompt_id}") - response.raise_for_status() - return response.json() +def get_history(prompt_id): + """Retrieve the history of a given prompt using its ID.""" + with urllib.request.urlopen(f"http://{COMFY_HOST}/history/{prompt_id}") as response: + return json.loads(response.read()) def base64_encode(img_path): """Returns base64 encoded image.""" @@ -252,21 +169,7 @@ def base64_encode(img_path): return base64.b64encode(image_file.read()).decode("utf-8") def process_output_images(outputs, job_id): - """Process output images from ComfyUI. - - Args: - outputs: Dictionary of outputs from ComfyUI - job_id: ID of the job - - Returns: - dict: Status of processing with image URLs - """ - # Set job ID for logging - job_id_filter.set_job_id(job_id) - - if not outputs: - return {"status": "error", "message": "No outputs found"} - + """Process generated images and return as S3 URL or base64.""" COMFY_OUTPUT_PATH = os.environ.get("COMFY_OUTPUT_PATH", "/comfyui/output") output_images = [] @@ -299,23 +202,15 @@ def process_output_images(outputs, job_id): "message": processed_images, } -def verify_lora_files(lora_info, job_id=None): - """Verify that lora files are accessible to ComfyUI. +def verify_lora_files(lora_info): + """Verify that downloaded lora files are accessible to ComfyUI. 
Args: lora_info: List of dictionaries with lora information - job_id: ID of the job for logging Returns: - bool: True if all lora files are accessible, False otherwise + bool: True if all files are accessible, False otherwise """ - # Set job ID for logging if provided - if job_id: - job_id_filter.set_job_id(job_id) - - if not lora_info: - return True - try: # Import ComfyUI's folder_paths module to check lora paths import folder_paths @@ -324,35 +219,19 @@ def verify_lora_files(lora_info, job_id=None): lora_folders = folder_paths.get_folder_paths("loras") logger.info(f"ComfyUI lora folders: {lora_folders}") - # Check which lora directory exists and use that - LORA_DIRS = [ - "/runpod-volume/models/loras", - "/workspace/models/loras" - ] - - our_lora_dirs = [] - for dir_path in LORA_DIRS: - if os.path.exists(dir_path): - our_lora_dirs.append(dir_path) - logger.info(f"Found lora directory: {dir_path}") - - if not our_lora_dirs: - logger.warning("No lora directories found") - return False - - # Check if our lora directories are in ComfyUI's lora folders - for our_lora_dir in our_lora_dirs: + # Check if our lora directory is in the ComfyUI lora folders + our_lora_dir = "/runpod-volume/models/loras" + if our_lora_dir not in lora_folders: + logger.warning(f"Our lora directory {our_lora_dir} is not in ComfyUI's lora folders") + # Try to add it to ComfyUI's folder paths + folder_paths.add_model_folder_path("loras", our_lora_dir) + logger.info(f"Added {our_lora_dir} to ComfyUI's lora folders") + + # Verify it was added + lora_folders = folder_paths.get_folder_paths("loras") if our_lora_dir not in lora_folders: - logger.warning(f"Our lora directory {our_lora_dir} is not in ComfyUI's lora folders") - # Try to add it to ComfyUI's folder paths - folder_paths.add_model_folder_path("loras", our_lora_dir) - logger.info(f"Added {our_lora_dir} to ComfyUI's lora folders") - - # Verify it was added - lora_folders = folder_paths.get_folder_paths("loras") - if our_lora_dir not in lora_folders: - logger.error(f"Failed to add {our_lora_dir} to ComfyUI's lora folders") - return False + logger.error(f"Failed to add {our_lora_dir} to ComfyUI's lora folders") + return False # Check if each downloaded lora file exists and is accessible for lora in lora_info: @@ -360,21 +239,7 @@ def verify_lora_files(lora_info, job_id=None): lora_path = lora.get("path") if not os.path.exists(lora_path): logger.error(f"Downloaded lora file does not exist: {lora_path}") - - # Check if the file exists in the other lora directory - filename = os.path.basename(lora_path) - found = False - for our_lora_dir in our_lora_dirs: - if our_lora_dir not in lora_path: # Check if it's a different directory - alt_path = os.path.join(our_lora_dir, filename) - if os.path.exists(alt_path): - logger.info(f"Found lora file in alternate location: {alt_path}") - lora["path"] = alt_path # Update the path - found = True - break - - if not found: - return False + return False # Check if the file is readable try: @@ -396,10 +261,6 @@ def verify_lora_files(lora_info, job_id=None): def handler(job): """Main handler for processing a job.""" job_id = job.get("id", str(uuid.uuid4())) # Use provided job ID or generate one - - # Set job ID for logging - job_id_filter.set_job_id(job_id) - job_input = job["input"] lora_file_paths = [] # Track downloaded lora files for cleanup start_time = datetime.utcnow().isoformat() @@ -419,7 +280,7 @@ def handler(job): update_redis(job_id, "NOT_STARTED", workflow=workflow) # Check ComfyUI availability - if not 
check_server(f"http://{COMFY_HOST}", COMFY_API_AVAILABLE_MAX_RETRIES, COMFY_API_AVAILABLE_INTERVAL_MS, job_id): + if not check_server(f"http://{COMFY_HOST}", COMFY_API_AVAILABLE_MAX_RETRIES, COMFY_API_AVAILABLE_INTERVAL_MS): error = "ComfyUI API unavailable" update_redis(job_id, "FAILED", error=error) @@ -434,7 +295,7 @@ def handler(job): return {"error": error, "job_id": job_id} # Upload images if provided - upload_result = upload_images(images, job_id) + upload_result = upload_images(images) if upload_result["status"] == "error": update_redis(job_id, "FAILED", error=upload_result["message"]) @@ -454,7 +315,7 @@ def handler(job): if loras: logger.info(f"Processing {len(loras)} loras") update_redis(job_id, "PROCESSING_LORAS") - lora_result = download_lora_files(loras, job_id) + lora_result = download_lora_files(loras) # Only track paths of downloaded loras, not local ones lora_file_paths = lora_result.get("file_paths", []) @@ -471,7 +332,7 @@ def handler(job): # Verify lora files are accessible to ComfyUI if lora_info: - if not verify_lora_files(lora_info, job_id): + if not verify_lora_files(lora_info): logger.warning("Some lora files may not be accessible to ComfyUI") update_redis(job_id, "LORA_VERIFICATION_WARNING", error="Some lora files may not be accessible to ComfyUI") else: @@ -538,7 +399,7 @@ def handler(job): # Queue the workflow try: update_redis(job_id, "IN_QUEUE") - queued_workflow = queue_workflow(workflow, job_id) + queued_workflow = queue_workflow(workflow) prompt_id = queued_workflow["prompt_id"] logger.info(f"Queued workflow with prompt ID {prompt_id}") except Exception as e: @@ -547,7 +408,7 @@ def handler(job): # Clean up lora files if any were downloaded if lora_file_paths: - cleanup_result = cleanup_lora_files(lora_file_paths, job_id) + cleanup_result = cleanup_lora_files(lora_file_paths) logger.info(f"Cleaned up lora files after error: {cleanup_result['message']}") logger.info(f"Deleted {len(cleanup_result.get('details', []))} files, skipped {len(cleanup_result.get('skipped_files', []))} files") @@ -567,7 +428,7 @@ def handler(job): retries = 0 try: while retries < COMFY_POLLING_MAX_RETRIES: - history = get_history(prompt_id, job_id) + history = get_history(prompt_id) if prompt_id in history and history[prompt_id].get("outputs"): break time.sleep(COMFY_POLLING_INTERVAL_MS / 1000) @@ -578,7 +439,7 @@ def handler(job): # Clean up lora files if any were downloaded if lora_file_paths: - cleanup_result = cleanup_lora_files(lora_file_paths, job_id) + cleanup_result = cleanup_lora_files(lora_file_paths) logger.info(f"Cleaned up lora files after timeout: {cleanup_result['message']}") logger.info(f"Deleted {len(cleanup_result.get('details', []))} files, skipped {len(cleanup_result.get('skipped_files', []))} files") @@ -599,7 +460,7 @@ def handler(job): # Clean up lora files if any were downloaded if lora_file_paths: - cleanup_result = cleanup_lora_files(lora_file_paths, job_id) + cleanup_result = cleanup_lora_files(lora_file_paths) logger.info(f"Cleaned up lora files after polling error: {cleanup_result['message']}") logger.info(f"Deleted {len(cleanup_result.get('details', []))} files, skipped {len(cleanup_result.get('skipped_files', []))} files") @@ -620,7 +481,7 @@ def handler(job): # Clean up lora files if any were downloaded cleanup_info = {} if lora_file_paths: - cleanup_result = cleanup_lora_files(lora_file_paths, job_id) + cleanup_result = cleanup_lora_files(lora_file_paths) logger.info(f"Cleaned up lora files after successful inference: {cleanup_result['message']}") 
logger.info(f"Deleted {len(cleanup_result.get('details', []))} files, skipped {len(cleanup_result.get('skipped_files', []))} files") cleanup_info = { @@ -827,9 +688,6 @@ def preload_weights(checkpoint_name): def init(): """Initialize the handler.""" - # Set a default job ID for initialization logs - job_id_filter.set_job_id('init') - # Wait for ComfyUI to start up logger.info("Waiting for ComfyUI to initialize...") time.sleep(15) # Give ComfyUI time to start @@ -916,41 +774,19 @@ def init(): logger.info(f"Preloaded {len(preloaded)} models: {preloaded}") -def download_lora_files(loras, job_id=None): +def download_lora_files(loras): """Download lora files from network paths to disk. Args: loras: List of dictionaries with 'path' and 'scale' keys - job_id: ID of the job for logging Returns: dict: Status of download operation with downloaded file paths and lora info """ - # Set job ID for logging if provided - if job_id: - job_id_filter.set_job_id(job_id) - if not loras: return {"status": "success", "message": "No loras to process", "details": [], "file_paths": [], "lora_info": []} - # Check which lora directory exists and use that - LORA_DIRS = [ - "/runpod-volume/models/loras", - "/workspace/models/loras" - ] - - LORA_DIR = None - for dir_path in LORA_DIRS: - if os.path.exists(dir_path): - LORA_DIR = dir_path - logger.info(f"Using lora directory: {LORA_DIR}") - break - - if LORA_DIR is None: - # If neither exists, default to the first one and create it - LORA_DIR = LORA_DIRS[0] - logger.warning(f"No existing lora directory found, creating: {LORA_DIR}") - + LORA_DIR = "/runpod-volume/models/loras" os.makedirs(LORA_DIR, exist_ok=True) responses = [] @@ -975,43 +811,8 @@ def download_lora_files(loras, job_id=None): filename = f"{filename}.safetensors" local_path = os.path.join(LORA_DIR, filename) - - # Check if the file already exists - file_exists = os.path.exists(local_path) - - # Check if this is a protected model - is_protected = is_protected_model(filename) - - if file_exists: - logger.info(f"Lora file already exists: {local_path}") - - # Only add to file_paths for cleanup if it's not a protected model - # (i.e., only if it has "pytorch_lora_weights" in the name) - if not is_protected: - file_paths.append(local_path) - logger.info(f"Added existing lora path to file_paths for cleanup: {local_path}") - else: - logger.info(f"Not adding protected model to cleanup list: {local_path}") - - # Add to lora info with local path - lora_info.append({ - "path": local_path, - "scale": scale, - "original_path": path, - "downloaded": False, # Mark as not downloaded since it already existed - "lora_name": filename # Add lora_name for ComfyUI workflow - }) - - responses.append(f"Using existing lora file: {local_path}") - continue - - # If we get here, we need to download the file - # Only add to file_paths for cleanup if it's not a protected model - if not is_protected: - file_paths.append(local_path) - logger.info(f"Added downloaded lora path to file_paths for cleanup: {local_path}") - else: - logger.info(f"Not adding protected model to cleanup list: {local_path}") + file_paths.append(local_path) + logger.info(f"Added downloaded lora path to file_paths: {local_path}") try: # Download the file @@ -1076,39 +877,15 @@ def download_lora_files(loras, job_id=None): "lora_info": lora_info } -def is_protected_model(filename): - """Check if a file is a protected model that should never be deleted. 
- - Args: - filename: Name of the file to check - - Returns: - bool: True if the file is protected, False otherwise - """ - # Check if the file has the signature of a model that should be deleted - # Only delete files with "pytorch_lora_weights" in the name - if "pytorch_lora_weights" in filename: - logger.info(f"File {filename} is not protected (has pytorch_lora_weights in name)") - return False - - # All other models are protected - logger.info(f"File {filename} is protected (does not have pytorch_lora_weights in name)") - return True - -def cleanup_lora_files(file_paths, job_id=None): +def cleanup_lora_files(file_paths): """Delete downloaded lora files after inference is complete. Args: file_paths: List of file paths to delete - job_id: ID of the job for logging Returns: dict: Status of cleanup operation """ - # Set job ID for logging if provided - if job_id: - job_id_filter.set_job_id(job_id) - if not file_paths: return {"status": "success", "message": "No lora files to clean up"} @@ -1118,43 +895,29 @@ def cleanup_lora_files(file_paths, job_id=None): deleted_files = [] skipped_files = [] - # Define the directories where downloaded loras are stored - LORA_DIRS = [ - "/runpod-volume/models/loras", - "/workspace/models/loras" - ] + # Define the directory where downloaded loras are stored + LORA_DIR = "/runpod-volume/models/loras" for file_path in file_paths: try: - # Check if the file exists - if not os.path.exists(file_path): - reason = "non-existent file" - logger.warning(f"Skipping deletion of {reason}: {file_path}") - skipped_files.append(file_path) - continue + # Only delete files that are in the download directory and have "pytorch_lora_weights.safetensors" in the filename + if (os.path.exists(file_path) and + file_path.startswith(LORA_DIR) and + "pytorch_lora_weights.safetensors" in file_path): - # Check if the file is in one of the lora directories - is_in_lora_dir = any(file_path.startswith(lora_dir) for lora_dir in LORA_DIRS) - - # Get the filename - filename = os.path.basename(file_path) - - # Check if the file is a protected model - if is_protected_model(filename): - reason = "protected model file (does not have pytorch_lora_weights in name)" - logger.warning(f"Skipping deletion of {reason}: {file_path}") - skipped_files.append(file_path) - continue - - # Only delete files that are in the download directories AND were downloaded from URLs - # We can identify URL downloads because they were added to file_paths in download_lora_files - if is_in_lora_dir: - logger.info(f"Deleting lora file (has pytorch_lora_weights in name): {file_path}") + logger.info(f"Deleting lora file: {file_path}") os.remove(file_path) deleted_files.append(file_path) logger.info(f"Successfully deleted lora file: {file_path}") else: - reason = "local file (not in download directory)" + # Skip files that don't meet the criteria + if not os.path.exists(file_path): + reason = "non-existent file" + elif not file_path.startswith(LORA_DIR): + reason = "local file (not in download directory)" + else: + reason = "filename does not contain 'pytorch_lora_weights.safetensors'" + logger.warning(f"Skipping deletion of {reason}: {file_path}") skipped_files.append(file_path) except Exception as e: @@ -1173,7 +936,7 @@ def cleanup_lora_files(file_paths, job_id=None): return { "status": "success", - "message": "All lora files cleaned up successfully", + "message": "All matching lora files cleaned up successfully", "details": deleted_files, "skipped_files": skipped_files } @@ -1192,9 +955,6 @@ def call_webhook(webhook_url, 
job_id, status, result=None, error=None, additiona Returns: dict: Status of webhook call """ - # Set job ID for logging - job_id_filter.set_job_id(job_id) - if not webhook_url: return {"status": "skipped", "message": "No webhook URL provided"} From 986dc87c2241c94127df0e8c840eedf05dab179b Mon Sep 17 00:00:00 2001 From: Michael Aubry Date: Wed, 19 Mar 2025 12:51:48 +0100 Subject: [PATCH 8/8] chore: do not preload models --- src/rp_handler.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/src/rp_handler.py b/src/rp_handler.py index fd669d4e..53ea8338 100644 --- a/src/rp_handler.py +++ b/src/rp_handler.py @@ -752,8 +752,8 @@ def init(): # Models to preload models_to_preload = [ - "realvisxlV40_v40LightningBakedvae.safetensors", - "pixelArtDiffusionXL_pixelWorld.safetensors" + # "realvisxlV40_v40LightningBakedvae.safetensors", + # "pixelArtDiffusionXL_pixelWorld.safetensors" ] # Check if FLUX environment variable is set to true @@ -762,17 +762,17 @@ def init(): logger.info("FLUX environment variable is set to true, adding flux1-dev.safetensors to preload list") # Preload models - preloaded = [] - for checkpoint_name in models_to_preload: - try: - result = preload_weights(checkpoint_name) - if result is not None: - preloaded.append(checkpoint_name) - logger.info(f"Successfully preloaded: {checkpoint_name}") - except Exception as e: - logger.error(f"Failed to preload {checkpoint_name}: {str(e)}") + # preloaded = [] + # for checkpoint_name in models_to_preload: + # try: + # result = preload_weights(checkpoint_name) + # if result is not None: + # preloaded.append(checkpoint_name) + # logger.info(f"Successfully preloaded: {checkpoint_name}") + # except Exception as e: + # logger.error(f"Failed to preload {checkpoint_name}: {str(e)}") - logger.info(f"Preloaded {len(preloaded)} models: {preloaded}") + # logger.info(f"Preloaded {len(preloaded)} models: {preloaded}") def download_lora_files(loras): """Download lora files from network paths to disk.
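Taken together, patches 6 through 8 converge on a deliberately narrow deletion rule: cleanup removes only files that live under /runpod-volume/models/loras and contain pytorch_lora_weights.safetensors in their name; everything else is skipped as a protected or user-provided model. A minimal sketch of that predicate under those assumptions (the helper name should_delete is illustrative and does not appear in the patches):

import os

LORA_DIR = "/runpod-volume/models/loras"

def should_delete(file_path: str) -> bool:
    # Mirrors the final check in cleanup_lora_files from PATCH 7/8: the file
    # must exist, sit in the download directory, and look like a downloaded
    # lora rather than a protected base model.
    return (
        os.path.exists(file_path)
        and file_path.startswith(LORA_DIR)
        and "pytorch_lora_weights.safetensors" in file_path
    )

# Illustrative behavior (paths are examples only):
#   /runpod-volume/models/loras/job42_pytorch_lora_weights.safetensors        -> deleted
#   /runpod-volume/models/loras/realvisxlV40_v40LightningBakedvae.safetensors -> skipped (protected)
#   /workspace/models/loras/pytorch_lora_weights.safetensors                  -> skipped (outside LORA_DIR)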