From 1091083511a7abd6dc7429e04bc36966a01ada20 Mon Sep 17 00:00:00 2001 From: Taylor Finley Date: Thu, 26 Jun 2025 10:35:16 -1000 Subject: [PATCH] add --no-skip option to overwrite, append, or prepend to existing caption files. --- scripts/batch-caption.py | 135 +++++++++++++++++++++++++++++---------- 1 file changed, 100 insertions(+), 35 deletions(-) diff --git a/scripts/batch-caption.py b/scripts/batch-caption.py index ca5e58e..a8170c0 100755 --- a/scripts/batch-caption.py +++ b/scripts/batch-caption.py @@ -44,6 +44,7 @@ def none_or_type(value, desired_type): parser.add_argument("--model", type=str, default="fancyfeast/llama-joycaption-beta-one-hf-llava", help="Model to use") parser.add_argument("--prepend", type=str, default="", help="String to prepend to all captions") parser.add_argument("--append", type=str, default="", help="String to append to all captions") +parser.add_argument("--no-skip", choices=['overwrite', 'append', 'prepend'], help="How to handle existing caption files") PIL.Image.MAX_IMAGE_PIXELS = 933120000 # Quiets Pillow from giving warnings on really large images (WARNING: Exposes a risk of DoS from malicious images) @@ -73,9 +74,10 @@ def main(): logging.warning("No images found") return logging.info(f"Found {len(image_paths)} images") - - # Ignore all images that already have captions - image_paths = [path for path in image_paths if not Path(path).with_suffix(".txt").exists()] + + # Ignore all images that already have captions (unless --no-skip is specified) + if args.no_skip is None: + image_paths = [path for path in image_paths if not Path(path).with_suffix(".txt").exists()] # Load JoyCaption tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=True) @@ -129,7 +131,7 @@ def main(): captions = [c.strip() for c in captions] for path, caption in zip(batch['paths'], captions): - write_caption(Path(path), args.prepend + caption + args.append) + write_caption(Path(path), args.prepend + caption + args.append, args.no_skip) pbar.update(len(captions)) @@ -140,36 +142,99 @@ def trim_off_prompt(input_ids: list[int], eoh_id: int, eot_id: int) -> list[int] i = input_ids.index(eoh_id) except ValueError: break - + input_ids = input_ids[i + 1:] - + # Trim off the end try: i = input_ids.index(eot_id) except ValueError: return input_ids - + return input_ids[:i] -def write_caption(image_path: Path, caption: str): +def write_caption(image_path: Path, caption: str, no_skip_mode: str | None): caption_path = image_path.with_suffix(".txt") - try: - f = os.open(caption_path, os.O_WRONLY | os.O_CREAT | os.O_EXCL) # Write-only, create if not exist, fail if exists - except FileExistsError: - logging.warning(f"Caption file '{caption_path}' already exists") - return - except Exception as e: - logging.error(f"Failed to open caption file '{caption_path}': {e}") - return - - try: - os.write(f, caption.encode("utf-8")) - os.close(f) - except Exception as e: - logging.error(f"Failed to write caption to '{caption_path}': {e}") - return + if no_skip_mode is None: + # Original behavior - fail if file exists + try: + f = os.open(caption_path, os.O_WRONLY | os.O_CREAT | os.O_EXCL) # Write-only, create if not exist, fail if exists + except FileExistsError: + logging.warning(f"Caption file '{caption_path}' already exists") + return + except Exception as e: + logging.error(f"Failed to open caption file '{caption_path}': {e}") + return + + try: + os.write(f, caption.encode("utf-8")) + os.close(f) + except Exception as e: + logging.error(f"Failed to write caption to '{caption_path}': {e}") + return + else: + # Handle --no-skip modes + if no_skip_mode == 'overwrite': + # Simply overwrite the file + try: + with open(caption_path, 'w', encoding='utf-8') as f: + f.write(caption) + except Exception as e: + logging.error(f"Failed to write caption to '{caption_path}': {e}") + return + elif no_skip_mode == 'append': + # Read existing content, append new caption + existing_content = "" + if caption_path.exists(): + try: + with open(caption_path, 'r', encoding='utf-8') as f: + existing_content = f.read().strip() + except Exception as e: + logging.error(f"Failed to read existing caption from '{caption_path}': {e}") + return + + # Ensure existing content ends with ". " + if existing_content and not existing_content.endswith('. '): + if existing_content.endswith('.'): + existing_content += ' ' + else: + existing_content += '. ' + + final_content = existing_content + caption + try: + with open(caption_path, 'w', encoding='utf-8') as f: + f.write(final_content) + except Exception as e: + logging.error(f"Failed to write caption to '{caption_path}': {e}") + return + elif no_skip_mode == 'prepend': + # Read existing content, prepend new caption + existing_content = "" + if caption_path.exists(): + try: + with open(caption_path, 'r', encoding='utf-8') as f: + existing_content = f.read().strip() + except Exception as e: + logging.error(f"Failed to read existing caption from '{caption_path}': {e}") + return + + # Ensure new caption ends with ". " + new_caption = caption + if new_caption and not new_caption.endswith('. '): + if new_caption.endswith('.'): + new_caption += ' ' + else: + new_caption += '. ' + + final_content = new_caption + existing_content + try: + with open(caption_path, 'w', encoding='utf-8') as f: + f.write(final_content) + except Exception as e: + logging.error(f"Failed to write caption to '{caption_path}': {e}") + return class ImageDataset(Dataset): @@ -180,10 +245,10 @@ def __init__(self, prompts: list[Prompt], paths: list[Path], tokenizer: PreTrain self.image_token_id = image_token_id self.image_seq_length = image_seq_length self.pad_token_id = tokenizer.pad_token_id - + def __len__(self): return len(self.paths) - + def __getitem__(self, idx: int) -> dict: path = self.paths[idx] @@ -231,7 +296,7 @@ def __getitem__(self, idx: int) -> dict: input_tokens.extend([self.image_token_id] * self.image_seq_length) else: input_tokens.append(token) - + input_ids = torch.tensor(input_tokens, dtype=torch.long) attention_mask = torch.ones_like(input_ids) @@ -273,15 +338,15 @@ def parse_prompts(prompt_str: str | None, prompt_file: str | None) -> list[Promp if prompt_str is not None: return [Prompt(prompt=prompt_str, weight=1.0)] - + if prompt_file is None: raise ValueError("Must specify either --prompt or --prompt-file") - + data = json.loads(Path(prompt_file).read_text()) if not isinstance(data, list): raise ValueError("Expected JSON file to contain a list of prompts") - + prompts = [] for item in data: @@ -291,25 +356,25 @@ def parse_prompts(prompt_str: str | None, prompt_file: str | None) -> list[Promp prompts.append(Prompt(prompt=item["prompt"], weight=item["weight"])) else: raise ValueError(f"Invalid prompt in JSON file. Should be either a string or an object with 'prompt' and 'weight' fields: {item}") - + if len(prompts) == 0: raise ValueError("No prompts found in JSON file") - + if sum(p.weight for p in prompts) <= 0.0: raise ValueError("Prompt weights must sum to a positive number") - + return prompts def find_images(glob: str | None, filelist: str | Path | None) -> list[Path]: if glob is None and filelist is None: raise ValueError("Must specify either --glob or --filelist") - + paths = [] if glob is not None: paths.extend(Path(".").glob(glob)) - + if filelist is not None: paths.extend((Path(line.strip()) for line in Path(filelist).read_text().strip().splitlines() if line.strip() != "")) @@ -317,4 +382,4 @@ def find_images(glob: str | None, filelist: str | Path | None) -> list[Path]: if __name__ == "__main__": - main() + main() \ No newline at end of file