diff --git a/.gitignore b/.gitignore
index ac43932..b64444d 100644
--- a/.gitignore
+++ b/.gitignore
@@ -23,3 +23,12 @@ poetry.lock
 
 # PyCharm
 .idea
+
+# Wandb
+wandb/
+
+# Checkpoints
+run/
+
+# Evaluation
+evaluation/
diff --git a/examples/ntc/config.py b/examples/ntc/config.py
index 594536e..2bb1fb2 100644
--- a/examples/ntc/config.py
+++ b/examples/ntc/config.py
@@ -8,6 +8,7 @@ def get_config():
   return ml_collections.ConfigDict(dict(
       label="base configuration",
       run=0,
+      wandb_project="Image Compression",  # Wandb project name
 
       debug_nans=False,
       checkify=False,
@@ -15,7 +16,11 @@ def get_config():
       lmbda=8.0,
       log_sigma=4.0,
       learning_rate=1e-4,
-      temperature=float("inf"),
+      temperature=float("inf"),  # Initial temperature for dynamic schedule
+      max_temp=1.0,  # Maximum temperature for dynamic schedule
+      min_temp=0.2,  # Minimum temperature for dynamic schedule
+      bound_epoch=200,  # Epoch boundary for temperature and learning rate reduction
+      dynamic_t=False,  # Whether to use dynamic temperature schedule
 
       num_epochs=1000,
       num_steps_per_epoch=1000,
diff --git a/examples/ntc/ntc.py b/examples/ntc/ntc.py
index f26318d..7b35bc5 100644
--- a/examples/ntc/ntc.py
+++ b/examples/ntc/ntc.py
@@ -31,6 +31,8 @@
 import jax
 from jax import numpy as jnp
 
+from codex.loss import wasserstein
+
 Array = jax.Array
 
 
@@ -223,8 +225,9 @@ def __call__(self, x, rng=None, t=None):
     x_rec = self.synthesis(y)
    x_rec = x_rec[:, : x.shape[-2], : x.shape[-1]]
-
-    distortion = jnp.square(x - x_rec).sum() / num_pixels
+    # distortion = jnp.square(x - x_rec).sum() / num_pixels
+    log2_sigma = jnp.full((x.shape[-2], x.shape[-1]), 4.0)  # TODO: log2_sigma should be dynamic
+    distortion = wasserstein.vgg16_wasserstein_distortion(x, x_rec, log2_sigma=log2_sigma)
 
     return x_rec, dict(
        rate=rate,
diff --git a/examples/ntc/train.py b/examples/ntc/train.py
index f19a6bd..3ecb20e 100644
--- a/examples/ntc/train.py
+++ b/examples/ntc/train.py
@@ -1,6 +1,8 @@
 """Runs an NTC training loop."""
 
 import os
+import sys
+os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
 from absl import app
 from absl import flags
 from absl import logging
@@ -10,13 +12,18 @@
 import tensorflow as tf
 import tensorflow_datasets as tfds
 
-import train_lib
+project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '../..'))
+sys.path.insert(0, project_root)
+
+
+from examples.ntc import train_lib
+from codex.loss import pretrained_features
 
 # pyright: reportUnusedCallResult=false
 config_flags.DEFINE_config_file("config")
 flags.DEFINE_string(
     "checkpoint_path",
-    "/tmp/train",
+    "./run/train/",
     "Directory where to write checkpoints.",
 )
 flags.DEFINE_string(
@@ -24,6 +31,56 @@
     None,
     "Directory to read initial checkpoint from.",
 )
+flags.DEFINE_float(
+    "lmbda",
+    None,
+    "Override lambda value from config. If not set, uses config.lmbda.",
+)
+flags.DEFINE_float(
+    "lr",
+    None,
+    "Override learning rate from config. If not set, uses config.learning_rate.",
+)
+flags.DEFINE_integer(
+    "batch_size",
+    None,
+    "Override batch size from config. If not set, uses config.batch_size.",
+)
+flags.DEFINE_float(
+    "temperature",
+    None,
+    "Override temperature from config. If not set, uses config.temperature.",
+)
+flags.DEFINE_integer(
+    "y_channels",
+    None,
+    "Override y_channels from config. If not set, uses config.model_kwargs[model_cls].y_channels.",
+)
+flags.DEFINE_boolean(
+    "dynamic_t",
+    None,
+    "Override dynamic_t from config. If not set, uses config.dynamic_t.",
+)
+flags.DEFINE_string(
+    "wandb_project",
+    None,
+    "Override wandb_project from config. If not set, uses config.wandb_project.",
+)
+flags.DEFINE_float(
+    "max_temp",
+    None,
+    "Override max_temp from config. If not set, uses config.max_temp.",
+)
+flags.DEFINE_float(
+    "min_temp",
+    None,
+    "Override min_temp from config. If not set, uses config.min_temp.",
+)
+flags.DEFINE_integer(
+    "bound_epoch",
+    None,
+    "Override bound_epoch from config. If not set, uses config.bound_epoch.",
+)
 
 
 FLAGS = flags.FLAGS
@@ -68,6 +125,86 @@ def main(_):
 
   jax.config.update("jax_debug_nans", FLAGS.config.debug_nans)
 
+  # Override lambda value if provided via command line
+  if FLAGS.lmbda is not None:
+    original_lmbda = FLAGS.config.lmbda
+    FLAGS.config.lmbda = FLAGS.lmbda
+    logging.info(f"Overriding lambda: {original_lmbda} -> {FLAGS.lmbda}")
+  else:
+    logging.info(f"Using lambda from config: {FLAGS.config.lmbda}")
+
+  # Override learning rate if provided via command line
+  if FLAGS.lr is not None:
+    original_lr = FLAGS.config.learning_rate
+    FLAGS.config.learning_rate = FLAGS.lr
+    logging.info(f"Overriding learning rate: {original_lr} -> {FLAGS.lr}")
+  else:
+    logging.info(f"Using learning rate from config: {FLAGS.config.learning_rate}")
+
+  if FLAGS.batch_size is not None:
+    original_batch_size = FLAGS.config.batch_size
+    FLAGS.config.batch_size = FLAGS.batch_size
+    logging.info(f"Overriding batch size: {original_batch_size} -> {FLAGS.batch_size}")
+  else:
+    logging.info(f"Using batch size from config: {FLAGS.config.batch_size}")
+
+  if FLAGS.temperature is not None:
+    original_temperature = FLAGS.config.temperature
+    FLAGS.config.temperature = FLAGS.temperature
+    logging.info(f"Overriding temperature: {original_temperature} -> {FLAGS.temperature}")
+  else:
+    logging.info(f"Using temperature from config: {FLAGS.config.temperature}")
+
+  if FLAGS.y_channels is not None:
+    model_cls = FLAGS.config.model_cls
+    original_y_channels = FLAGS.config.model_kwargs[model_cls]["y_channels"]
+    FLAGS.config.model_kwargs[model_cls]["y_channels"] = FLAGS.y_channels
+    logging.info(f"Overriding y_channels for {model_cls}: {original_y_channels} -> {FLAGS.y_channels}")
+  else:
+    model_cls = FLAGS.config.model_cls
+    logging.info(f"Using y_channels from config for {model_cls}: {FLAGS.config.model_kwargs[model_cls]['y_channels']}")
+
+  # Override dynamic_t if provided via command line
+  if FLAGS.dynamic_t is not None:
+    original_dynamic_t = FLAGS.config.dynamic_t
+    FLAGS.config.dynamic_t = FLAGS.dynamic_t
+    logging.info(f"Overriding dynamic_t: {original_dynamic_t} -> {FLAGS.dynamic_t}")
+  else:
+    logging.info(f"Using dynamic_t from config: {FLAGS.config.dynamic_t}")
+
+  # Override wandb_project if provided via command line
+  if FLAGS.wandb_project is not None:
+    original_wandb_project = FLAGS.config.wandb_project
+    FLAGS.config.wandb_project = FLAGS.wandb_project
+    logging.info(f"Overriding wandb_project: {original_wandb_project} -> {FLAGS.wandb_project}")
+  else:
+    logging.info(f"Using wandb_project from config: {FLAGS.config.wandb_project}")
+
+  # Override max_temp if provided via command line
+  if FLAGS.max_temp is not None:
+    original_max_temp = FLAGS.config.max_temp
+    FLAGS.config.max_temp = FLAGS.max_temp
+    logging.info(f"Overriding max_temp: {original_max_temp} -> {FLAGS.max_temp}")
+  else:
+    logging.info(f"Using max_temp from config: {FLAGS.config.max_temp}")
+
+  # Override min_temp if provided via command line
+  if FLAGS.min_temp is not None:
+    original_min_temp = FLAGS.config.min_temp
+    FLAGS.config.min_temp = FLAGS.min_temp
+    logging.info(f"Overriding min_temp: {original_min_temp} -> {FLAGS.min_temp}")
+  else:
+    logging.info(f"Using min_temp from config: {FLAGS.config.min_temp}")
+
+  # Override bound_epoch if provided via command line
+  if FLAGS.bound_epoch is not None:
+    original_bound_epoch = FLAGS.config.bound_epoch
+    FLAGS.config.bound_epoch = FLAGS.bound_epoch
+    logging.info(f"Overriding bound_epoch: {original_bound_epoch} -> {FLAGS.bound_epoch}")
+  else:
+    logging.info(f"Using bound_epoch from config: {FLAGS.config.bound_epoch}")
+
+
   host_count = jax.process_count()
   # host_id = jax.process_index()
   local_device_count = jax.local_device_count()
@@ -86,6 +223,11 @@ def main(_):
   )
   train_iterator = train_set.as_numpy_iterator()
 
+  # Load VGG16 model for Wasserstein distortion calculation
+  logging.info("Loading VGG16 model for perceptual loss...")
+  pretrained_features.load_vgg16_model()
+  logging.info("VGG16 model loaded successfully.")
+
   train_lib.train(
       FLAGS.config,
       FLAGS.checkpoint_path,
diff --git a/examples/ntc/train_lib.py b/examples/ntc/train_lib.py
index b600dee..d923e65 100644
--- a/examples/ntc/train_lib.py
+++ b/examples/ntc/train_lib.py
@@ -7,8 +7,11 @@
 import equinox as eqx
 import jax
 import optax
+import wandb
+from tqdm import tqdm
 
-import ntc
+# import ntc
+from examples.ntc import ntc
 
 
 @eqx.filter_jit
@@ -16,9 +19,9 @@ def evaluate(model, x):
   return model(x, None, None)
 
 
-def save_state(path, model, epoch, opt_state):
+def save_state(path, model, epoch, opt_state, config):
   state = (model, epoch, opt_state)
-  fn_state = f"{path}/state.eqx"
+  fn_state = f"{path}/state_{config.lmbda}.eqx"
   with open(fn_state, "wb") as f:
     eqx.tree_serialise_leaves(f, state)
 
@@ -57,12 +60,65 @@ def train(config, checkpoint_path, train_iterator, rng, start_path=None):
   if start_path is None:
     start_path = checkpoint_path
 
-  lr_schedule = optax.schedules.piecewise_constant_schedule(config.learning_rate, {})
-  t_schedule = optax.schedules.piecewise_constant_schedule(config.temperature, {})
-  optimizer = optax.adam(learning_rate=lr_schedule)
+  # Initialize wandb
+  wandb.init(
+      project=config.wandb_project,
+      name=f"{config.model_cls}_lambda{config.lmbda}_lr{config.learning_rate}_adam",
+      config={
+          "learning_rate": config.learning_rate,
+          "batch_size": config.batch_size,
+          "patch_size": config.patch_size,
+          "model": config.model_cls,
+          "lambda": config.lmbda,
+          "temperature": config.temperature,
+          "optimizer": "Adam",
+          "lr_schedule": "piecewise_constant",
+      }
+  )
+
+  # Create temperature schedule based on dynamic_t configuration
+  if hasattr(config, 'dynamic_t') and config.dynamic_t:
+    # Dynamic temperature schedule: linear decrease from max_temp to min_temp over the first bound_epoch epochs
+    # Using piecewise_interpolate_schedule for JAX-compatible conditional logic
+    temperature_schedule = optax.schedules.piecewise_interpolate_schedule(
+        init_value=config.max_temp,
+        boundaries_and_scales={
+            config.bound_epoch * config.num_steps_per_epoch: config.min_temp / config.max_temp,  # At bound_epoch, scale to min_temp
+        },
+        interpolate_type='linear'  # Linear interpolation between boundaries
+    )
+    logging.info(f"Using dynamic temperature schedule: {config.max_temp} -> {config.min_temp} over first {config.bound_epoch} epochs, then fixed at {config.min_temp}")
+  else:
+    # Fixed temperature schedule using config.temperature
+    temperature_schedule = optax.schedules.piecewise_constant_schedule(config.temperature, {})
+    logging.info(f"Using fixed temperature schedule: {config.temperature}")
+
+  # Initialize optimizer with initial learning rate
+  optimizer = optax.adam(learning_rate=config.learning_rate)
 
   os.makedirs(checkpoint_path, exist_ok=True)
 
+  # Save configuration parameters to config.txt file
+  config_file_path = os.path.join(checkpoint_path, "config.txt")
+  with open(config_file_path, "w") as f:
+    f.write("Configuration Parameters:\n")
+    f.write("=" * 50 + "\n\n")
+
+    # Save all config parameters except lmbda
+    for key, value in config.items():
+      if key != "lmbda":  # Exclude lmbda as requested
+        if isinstance(value, dict):
+          f.write(f"{key}:\n")
+          for sub_key, sub_value in value.items():
+            f.write(f"  {sub_key}: {sub_value}\n")
+        else:
+          f.write(f"{key}: {value}\n")
+
+    f.write("\n" + "=" * 50 + "\n")
+    f.write(f"Configuration saved at: {config_file_path}\n")
+
+  logging.info(f"Configuration saved to: {config_file_path}")
+
   rng, init_rng = jax.random.split(rng)
   model = instantiate_model(init_rng, config)
   opt_state = optimizer.init(eqx.filter(model, eqx.is_array))
@@ -71,17 +127,29 @@
   except IOError:
     start_epoch = 0
 
+  # Initialize learning rate reduction tracking
+  lr_reduction_counter = 0
+  previous_training_loss = None
+  current_learning_rate = config.learning_rate
+
+  # Initialize global step counter for temperature scheduling (not tied to optimizer state)
+  global_step = 0
+
+  # Initialize current temperature (will be updated each epoch)
+  current_temperature = config.max_temp if hasattr(config, 'dynamic_t') and config.dynamic_t else config.temperature
+
+  # Initialize best validation distortion tracking
+  best_val_distortion = float('inf')
+
   @eqx.filter_jit
-  def train_step(model, opt_state, x, rng):
+  def train_step(model, opt_state, x, rng, temperature):
     logging.info("Compiling train_step.")
-    lr = lr_schedule(opt_state[0].count)
-    t = t_schedule(opt_state[0].count)
     grad_fn = eqx.filter_grad(ntc.batched_loss_fn, has_aux=True)
     rng = jax.random.split(rng, x.shape[0])
-    grads, metrics = grad_fn(model, x, config.lmbda, rng, t)
+    grads, metrics = grad_fn(model, x, config.lmbda, rng, temperature)
     update, opt_state = optimizer.update(grads, opt_state)
     model = eqx.apply_updates(model, update)
-    metrics.update(lr=lr, t=t)
+    metrics.update(lr=current_learning_rate, t=temperature)
     return model, opt_state, metrics
 
   @eqx.filter_jit
@@ -94,31 +162,137 @@ def eval_step(model, x):
     train_step = checkify(train_step)
     eval_step = checkify(eval_step)
 
-  for epoch in range(start_epoch, config.num_epochs):
+  # Create progress bar for epochs
+  epoch_pbar = tqdm(range(start_epoch, config.num_epochs),
+                    desc="Epochs",
+                    position=0,
+                    initial=start_epoch)
+
+  for i, epoch in enumerate(epoch_pbar):
     logging.info("Starting epoch %d.", epoch)
-    save_state(checkpoint_path, model, epoch, opt_state)
+
+    # Update temperature for this epoch (only for dynamic temperature)
+    if hasattr(config, 'dynamic_t') and config.dynamic_t:
+      if epoch >= config.bound_epoch:
+        current_temperature = config.min_temp
+      else:
+        # Calculate temperature based on current epoch progress
+        epoch_progress = epoch / config.bound_epoch  # Progress from 0 to 1 over bound_epoch epochs
+        current_temperature = config.max_temp - (config.max_temp - config.min_temp) * epoch_progress
+
+    # Log current learning rate and temperature at the start of epoch
+    logging.info(f"Epoch {epoch} - Current learning rate: {current_learning_rate:.2e}, temperature: {current_temperature:.4f}")
 
     metrics = collections.defaultdict(lambda: 0.0)
     step_metrics = dict()
-
-    for _ in range(config.num_steps_per_epoch):
+
+    # Create progress bar for training steps
+    train_pbar = tqdm(range(config.num_steps_per_epoch), desc="Training", position=1, leave=False)
+    for step in train_pbar:
       rng, train_rng = jax.random.split(rng)
       model, opt_state, step_metrics = train_step(
-          model, opt_state, next(train_iterator), train_rng
+          model, opt_state, next(train_iterator), train_rng, current_temperature
       )
       for k in step_metrics:
         metrics[k] += float(step_metrics[k])
+      # Update training progress bar with current loss, learning rate, and temperature
+      train_pbar.set_postfix({"loss": float(step_metrics["loss"]), "lr": f"{current_learning_rate:.2e}", "temp": f"{current_temperature:.3f}"})
+      # Debug: Check if step_metrics["t"] matches current_temperature
+      if step == 0:  # Only log once per epoch to avoid spam
+        logging.info(f"Debug - current_temperature: {current_temperature:.4f}, step_metrics['t']: {float(step_metrics['t']):.4f}")
+      global_step += 1  # Increment global step counter
 
     for k in step_metrics:
       metrics[k] /= config.num_steps_per_epoch
 
-    for _ in range(config.num_eval_steps):
+    # Create progress bar for evaluation steps
+    eval_pbar = tqdm(range(config.num_eval_steps), desc="Evaluation", position=1, leave=False)
+    for _ in eval_pbar:
       step_metrics = eval_step(model, next(train_iterator))
       for k in step_metrics:
         metrics[k] += float(step_metrics[k])
+      # Update evaluation progress bar with current validation loss
+      eval_pbar.set_postfix({"val_loss": float(step_metrics["val_loss"])})
 
     for k in step_metrics:
       metrics[k] /= config.num_eval_steps
 
-    logging.info("Epoch %d metrics: %s", epoch, metrics)
+    # Learning rate reduction logic (after bound_epoch)
+    if epoch >= config.bound_epoch:
+      current_training_loss = metrics["loss"]
+
+      if previous_training_loss is not None:
+        loss_difference = abs(current_training_loss - previous_training_loss)
+
+        if loss_difference < 0.07:
+          lr_reduction_counter += 1
+          logging.info(f"Epoch {epoch}: Loss difference {loss_difference:.6f} < 0.07. Counter: {lr_reduction_counter}/30")
+        else:
+          lr_reduction_counter = 0
+          logging.info(f"Epoch {epoch}: Loss difference {loss_difference:.6f} >= 0.07. Resetting counter.")
Resetting counter.") + + # Reduce learning rate if counter reaches 10 + if lr_reduction_counter >= 30: + new_lr = current_learning_rate / 10.0 + if new_lr >= 1e-7: # Check minimum learning rate + current_learning_rate = new_lr + # Recreate optimizer with new learning rate + optimizer = optax.adam(learning_rate=current_learning_rate) + opt_state = optimizer.init(eqx.filter(model, eqx.is_array)) + logging.info(f"Epoch {epoch}: Reducing learning rate to {current_learning_rate:.2e} (global_step: {global_step})") + else: + logging.info(f"Epoch {epoch}: Learning rate would be {new_lr:.2e} < 1e-7, keeping at {current_learning_rate:.2e}") + lr_reduction_counter = 0 # Reset counter + + previous_training_loss = current_training_loss + + # Check if current validation distortion is the best so far + current_val_distortion = metrics["val_distortion"] + + # Check if we have a new best validation distortion + if current_val_distortion < best_val_distortion: # TODO use loss compare + # We have a new best, so save the checkpoint + logging.info(f"New best validation distortion: {current_val_distortion:.6f} < {best_val_distortion:.6f}") + save_state(checkpoint_path, model, epoch, opt_state, config) + best_val_distortion = current_val_distortion + else: + logging.info(f"No improvement: val_distortion {current_val_distortion:.6f} >= best_val_distortion {best_val_distortion:.6f}") + + # Log metrics to wandb + wandb.log({ + "epoch": epoch, + "loss": metrics["loss"], + "rate": metrics["rate"], + "distortion": metrics["distortion"], + "learning_rate": current_learning_rate, + "temperature": metrics["t"], + "val_loss": metrics["val_loss"], + "val_rate": metrics["val_rate"], + "val_distortion": metrics["val_distortion"], + "best_val_distortion": best_val_distortion, + }) + + # For HyperPriorModel, log additional metrics if they exist + if "rate_y" in metrics: + wandb.log({ + "rate_y": metrics["rate_y"], + "rate_z": metrics["rate_z"], + "val_rate_y": metrics["val_rate_y"], + "val_rate_z": metrics["val_rate_z"], + }) + + # Update epoch progress bar with metrics + epoch_pbar.set_postfix({ + "loss": metrics["loss"], + "val_loss": metrics["val_loss"], + "rate": metrics["rate"], + "distortion": metrics["distortion"], + "lr": f"{current_learning_rate:.2e}", + "temp": f"{current_temperature:.3f}" + }) + + # Print metrics on separate lines + logging.info("Epoch %d metrics:", epoch) + for metric_name, metric_value in metrics.items(): + logging.info(" %s: %f", metric_name, metric_value) nan_metrics = [k for k, v in metrics.items() if math.isnan(v)] if nan_metrics: @@ -126,4 +300,6 @@ def eval_step(model, x): f"Encountered NaN in metrics: {nan_metrics}. Stopping training." ) + # Save final checkpoint save_state(checkpoint_path, model, config.num_epochs, opt_state) + wandb.finish()