diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index 81bf87313ea3..5579ae86f6fe 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -245,6 +245,24 @@ def parse_args(): " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." ), ) + parser.add_argument( + "--log_name", + type=str, + default=None, + required=False, + help=( + "Name of log to identify experiment in reporting tool." + ), + ) + parser.add_argument( + "--log_group", + type=str, + default=None, + required=False, + help=( + "Name of log group to aggregate experiments in reporting tool." + ), + ) parser.add_argument( "--mixed_precision", type=str, @@ -644,7 +662,10 @@ def collate_fn(examples): # We need to initialize the trackers we use, and also store our configuration. # The trackers initializes automatically on the main process. if accelerator.is_main_process: - accelerator.init_trackers("text2image-fine-tune", config=vars(args)) + init_kwargs = {"wandb":{"settings":{"console": "off"}, + "name":f"{args.log_name}", + "group":f"{args.log_group}"}} + accelerator.init_trackers("text2image-fine-tune", config=vars(args), init_kwargs=init_kwargs) # Train! total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps