scripts/batch-caption.py: 135 changes (100 additions, 35 deletions)
@@ -44,6 +44,7 @@ def none_or_type(value, desired_type):
parser.add_argument("--model", type=str, default="fancyfeast/llama-joycaption-beta-one-hf-llava", help="Model to use")
parser.add_argument("--prepend", type=str, default="", help="String to prepend to all captions")
parser.add_argument("--append", type=str, default="", help="String to append to all captions")
parser.add_argument("--no-skip", choices=['overwrite', 'append', 'prepend'], help="How to handle existing caption files")


PIL.Image.MAX_IMAGE_PIXELS = 933120000 # Quiets Pillow from giving warnings on really large images (WARNING: Exposes a risk of DoS from malicious images)
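As a brief aside, the new flag is parsed like any other argparse option; a minimal sketch, assuming the parser defined above (the parsed value here is illustrative, not from the PR):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--no-skip", choices=['overwrite', 'append', 'prepend'], help="How to handle existing caption files")

args = parser.parse_args(["--no-skip", "append"])
assert args.no_skip == "append"  # when the flag is omitted, args.no_skip stays None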
@@ -73,9 +74,10 @@ def main():
logging.warning("No images found")
return
logging.info(f"Found {len(image_paths)} images")

# Ignore all images that already have captions
image_paths = [path for path in image_paths if not Path(path).with_suffix(".txt").exists()]

# Ignore all images that already have captions (unless --no-skip is specified)
if args.no_skip is None:
image_paths = [path for path in image_paths if not Path(path).with_suffix(".txt").exists()]

# Load JoyCaption
tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=True)
@@ -129,7 +131,7 @@ def main():
captions = [c.strip() for c in captions]

for path, caption in zip(batch['paths'], captions):
write_caption(Path(path), args.prepend + caption + args.append)
write_caption(Path(path), args.prepend + caption + args.append, args.no_skip)
pbar.update(len(captions))


@@ -140,36 +142,99 @@ def trim_off_prompt(input_ids: list[int], eoh_id: int, eot_id: int) -> list[int]
i = input_ids.index(eoh_id)
except ValueError:
break

input_ids = input_ids[i + 1:]

# Trim off the end
try:
i = input_ids.index(eot_id)
except ValueError:
return input_ids

return input_ids[:i]
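# Net effect (assuming the hidden lines wrap the first try in a loop, as the break implies):
# everything up to and including the last end-of-header token is discarded, and the
# remainder is truncated at the first end-of-turn token, or returned as-is if none is found.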


def write_caption(image_path: Path, caption: str):
def write_caption(image_path: Path, caption: str, no_skip_mode: str | None):
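    # Modes handled below:
    #   None        -> original behavior: create the .txt file only if it does not already exist
    #   'overwrite' -> replace any existing caption file
    #   'append'    -> keep the existing caption and add the new one after it, joined with ". "
    #   'prepend'   -> write the new caption first, then the existing one, joined with ". "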
caption_path = image_path.with_suffix(".txt")

try:
f = os.open(caption_path, os.O_WRONLY | os.O_CREAT | os.O_EXCL) # Write-only, create if not exist, fail if exists
except FileExistsError:
logging.warning(f"Caption file '{caption_path}' already exists")
return
except Exception as e:
logging.error(f"Failed to open caption file '{caption_path}': {e}")
return

try:
os.write(f, caption.encode("utf-8"))
os.close(f)
except Exception as e:
logging.error(f"Failed to write caption to '{caption_path}': {e}")
return
if no_skip_mode is None:
# Original behavior - fail if file exists
try:
f = os.open(caption_path, os.O_WRONLY | os.O_CREAT | os.O_EXCL) # Write-only, create if not exist, fail if exists
except FileExistsError:
logging.warning(f"Caption file '{caption_path}' already exists")
return
except Exception as e:
logging.error(f"Failed to open caption file '{caption_path}': {e}")
return

try:
os.write(f, caption.encode("utf-8"))
os.close(f)
except Exception as e:
logging.error(f"Failed to write caption to '{caption_path}': {e}")
return
else:
# Handle --no-skip modes
if no_skip_mode == 'overwrite':
# Simply overwrite the file
try:
with open(caption_path, 'w', encoding='utf-8') as f:
f.write(caption)
except Exception as e:
logging.error(f"Failed to write caption to '{caption_path}': {e}")
return
elif no_skip_mode == 'append':
# Read existing content, append new caption
existing_content = ""
if caption_path.exists():
try:
with open(caption_path, 'r', encoding='utf-8') as f:
existing_content = f.read().strip()
except Exception as e:
logging.error(f"Failed to read existing caption from '{caption_path}': {e}")
return

# Ensure existing content ends with ". "
if existing_content and not existing_content.endswith('. '):
if existing_content.endswith('.'):
existing_content += ' '
else:
existing_content += '. '

final_content = existing_content + caption
try:
with open(caption_path, 'w', encoding='utf-8') as f:
f.write(final_content)
except Exception as e:
logging.error(f"Failed to write caption to '{caption_path}': {e}")
return
elif no_skip_mode == 'prepend':
# Read existing content, prepend new caption
existing_content = ""
if caption_path.exists():
try:
with open(caption_path, 'r', encoding='utf-8') as f:
existing_content = f.read().strip()
except Exception as e:
logging.error(f"Failed to read existing caption from '{caption_path}': {e}")
return

# Ensure new caption ends with ". "
new_caption = caption
if new_caption and not new_caption.endswith('. '):
if new_caption.endswith('.'):
new_caption += ' '
else:
new_caption += '. '

final_content = new_caption + existing_content
try:
with open(caption_path, 'w', encoding='utf-8') as f:
f.write(final_content)
except Exception as e:
logging.error(f"Failed to write caption to '{caption_path}': {e}")
return
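To make the joining rule concrete, a small illustration with made-up caption strings (not part of the patch):

existing = "A photo of a cat"                # caption already on disk
new = "The cat is sitting on a windowsill."  # freshly generated caption

# Same rule as above: make sure the leading text ends with ". " before concatenating.
if existing and not existing.endswith('. '):
    existing += ' ' if existing.endswith('.') else '. '

print(existing + new)  # "A photo of a cat. The cat is sitting on a windowsill."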


class ImageDataset(Dataset):
@@ -180,10 +245,10 @@ def __init__(self, prompts: list[Prompt], paths: list[Path], tokenizer: PreTrain
self.image_token_id = image_token_id
self.image_seq_length = image_seq_length
self.pad_token_id = tokenizer.pad_token_id

def __len__(self):
return len(self.paths)

def __getitem__(self, idx: int) -> dict:
path = self.paths[idx]

@@ -231,7 +296,7 @@ def __getitem__(self, idx: int) -> dict:
input_tokens.extend([self.image_token_id] * self.image_seq_length)
else:
input_tokens.append(token)

input_ids = torch.tensor(input_tokens, dtype=torch.long)
attention_mask = torch.ones_like(input_ids)

@@ -273,15 +338,15 @@ def parse_prompts(prompt_str: str | None, prompt_file: str | None) -> list[Promp

if prompt_str is not None:
return [Prompt(prompt=prompt_str, weight=1.0)]

if prompt_file is None:
raise ValueError("Must specify either --prompt or --prompt-file")

data = json.loads(Path(prompt_file).read_text())

if not isinstance(data, list):
raise ValueError("Expected JSON file to contain a list of prompts")

prompts = []

for item in data:
@@ -291,30 +356,30 @@ def parse_prompts(prompt_str: str | None, prompt_file: str | None) -> list[Promp
prompts.append(Prompt(prompt=item["prompt"], weight=item["weight"]))
else:
raise ValueError(f"Invalid prompt in JSON file. Should be either a string or an object with 'prompt' and 'weight' fields: {item}")

if len(prompts) == 0:
raise ValueError("No prompts found in JSON file")

if sum(p.weight for p in prompts) <= 0.0:
raise ValueError("Prompt weights must sum to a positive number")

return prompts
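For reference, a sketch of a prompt file this parser would accept; the file name and prompt text are invented for illustration:

import json
from pathlib import Path

# Bare strings default to weight 1.0; objects carry an explicit weight.
Path("prompts.json").write_text(json.dumps([
    "Write a long, descriptive caption for this image.",
    {"prompt": "Write a short caption.", "weight": 2.0},
]))

prompts = parse_prompts(prompt_str=None, prompt_file="prompts.json")
# -> two Prompt objects, with weights 1.0 and 2.0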


def find_images(glob: str | None, filelist: str | Path | None) -> list[Path]:
if glob is None and filelist is None:
raise ValueError("Must specify either --glob or --filelist")

paths = []

if glob is not None:
paths.extend(Path(".").glob(glob))

if filelist is not None:
paths.extend((Path(line.strip()) for line in Path(filelist).read_text().strip().splitlines() if line.strip() != ""))
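    # The file list is plain text: one path per line; leading/trailing whitespace and blank lines are ignored.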

return paths


if __name__ == "__main__":
main()