9 changes: 9 additions & 0 deletions .gitignore
@@ -23,3 +23,12 @@ poetry.lock

# PyCharm
.idea

# Wandb
wandb/

# Checkpoints
run/

# Evaluation
evaluation/
7 changes: 6 additions & 1 deletion examples/ntc/config.py
@@ -8,14 +8,19 @@ def get_config():
return ml_collections.ConfigDict(dict(
label="base configuration",
run=0,
wandb_project="Image Compression", # Wandb project name

debug_nans=False,
checkify=False,

lmbda=8.0,
log_sigma=4.0,
learning_rate=1e-4,
temperature=float("inf"),
temperature=float("inf"), # Initial temperature for dynamic schedule
max_temp=1.0, # Maximum temperature for dynamic schedule
min_temp=0.2, # Minimum temperature for dynamic schedule
bound_epoch=200, # Epoch boundary for temperature and learning rate reduction
dynamic_t=False, # Whether to use dynamic temperature schedule

num_epochs=1000,
num_steps_per_epoch=1000,
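The new config fields (dynamic_t, max_temp, min_temp, bound_epoch) describe a dynamic temperature schedule that is applied inside train_lib, which is not part of this diff. A minimal sketch of one consistent reading, assuming a linear decay from max_temp to min_temp over the first bound_epoch epochs (the function name and the decay shape are assumptions, not taken from this PR):

import ml_collections

def temperature_for_epoch(epoch: int, config: ml_collections.ConfigDict) -> float:
    # Hypothetical schedule: decay linearly to min_temp until bound_epoch, then hold.
    if not config.dynamic_t:
        return config.temperature
    frac = min(epoch / config.bound_epoch, 1.0)
    return config.max_temp - (config.max_temp - config.min_temp) * frac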
7 changes: 5 additions & 2 deletions examples/ntc/ntc.py
@@ -31,6 +31,8 @@
import jax
from jax import numpy as jnp

from codex.loss import wasserstein

Array = jax.Array


@@ -223,8 +225,9 @@ def __call__(self, x, rng=None, t=None):

x_rec = self.synthesis(y)
x_rec = x_rec[:, : x.shape[-2], : x.shape[-1]]

distortion = jnp.square(x - x_rec).sum() / num_pixels
# distortion = jnp.square(x - x_rec).sum() / num_pixels
log2_sigma = jnp.full((x.shape[-2], x.shape[-1]), 4.0) # TODO: log2_sigma should be dynamic
distortion = wasserstein.vgg16_wasserstein_distortion(x, x_rec, log2_sigma=log2_sigma)

return x_rec, dict(
rate=rate,
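The distortion term now calls wasserstein.vgg16_wasserstein_distortion from codex.loss with a constant log2_sigma map of 4.0, matching config.log_sigma but ignoring image content (hence the TODO). One possible content-adaptive replacement, sketched purely as an illustration (the block size, the [lo, hi] range, and the variance-based mapping are assumptions, not part of this PR), assigns small sigma to detailed regions and large sigma to smooth ones:

from jax import numpy as jnp

def block_variance_log2_sigma(x, window=8, lo=0.0, hi=6.0):
    # Hypothetical per-pixel log2_sigma map from block-wise variance (illustration only).
    # x: (H, W) array; H and W are assumed to be divisible by `window`.
    h, w = x.shape
    blocks = x.reshape(h // window, window, w // window, window)
    var = jnp.var(blocks, axis=(1, 3))               # (H/window, W/window) block variance
    act = var / (jnp.max(var) + 1e-8)                # normalize to [0, 1]
    per_block = hi - (hi - lo) * act                 # detailed blocks -> small sigma
    return jnp.repeat(jnp.repeat(per_block, window, axis=0), window, axis=1)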
146 changes: 144 additions & 2 deletions examples/ntc/train.py
@@ -1,6 +1,8 @@
"""Runs an NTC training loop."""

import os
import sys
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
from absl import app
from absl import flags
from absl import logging
@@ -10,20 +12,75 @@
import tensorflow as tf
import tensorflow_datasets as tfds

import train_lib
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '../..'))
sys.path.insert(0, project_root)


from examples.ntc import train_lib
from codex.loss import pretrained_features

# pyright: reportUnusedCallResult=false
config_flags.DEFINE_config_file("config")
flags.DEFINE_string(
"checkpoint_path",
"/tmp/train",
"./run/train/",
"Directory where to write checkpoints.",
)
flags.DEFINE_string(
"start_path",
None,
"Directory to read initial checkpoint from.",
)
flags.DEFINE_float(
"lmbda",
None,
"Override lambda value from config. If not set, uses config.lmbda.",
)
flags.DEFINE_float(
"lr",
None,
"Override learning rate from config. If not set, uses config.learning_rate.",
)
flags.DEFINE_integer(
"batch_size",
None,
"Override batch size from config. If not set, uses config.batch_size.",
)
flags.DEFINE_float(
"temperature",
None,
"Override temperature from config. If not set, uses config.temperature.",
)
flags.DEFINE_integer(
"y_channels",
None,
"Override y_channels from config. If not set, uses config.model_kwargs[model_cls].y_channels.",
)
flags.DEFINE_boolean(
"dynamic_t",
None,
"Override dynamic_t from config. If not set, uses config.dynamic_t.",
)
flags.DEFINE_string(
"wandb_project",
None,
"Override wandb_project from config. If not set, uses config.wandb_project.",
)
flags.DEFINE_float(
"max_temp",
None,
"Override max_temp from config. If not set, uses config.max_temp.",
)
flags.DEFINE_float(
"min_temp",
None,
"Override min_temp from config. If not set, uses config.min_temp.",
)
flags.DEFINE_integer(
"bound_epoch",
None,
"Override bound_epoch from config. If not set, uses config.bound_epoch.",
)

FLAGS = flags.FLAGS
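Each of the override flags above is applied in main() with an identical if/else block (see the next hunk). A small helper, shown here only as a possible refactor and not part of this PR (the helper name and the flag-to-config-key mapping are assumptions), would remove that repetition:

from absl import logging

def apply_override(config, key, value):
    # Hypothetical helper: set config[key] from a CLI value when one was given.
    if value is not None:
        logging.info(f"Overriding {key}: {config[key]} -> {value}")
        config[key] = value
    else:
        logging.info(f"Using {key} from config: {config[key]}")

# Possible usage inside main(); note that the "lr" flag maps to "learning_rate":
# for flag, key in [("lmbda", "lmbda"), ("lr", "learning_rate"), ("batch_size", "batch_size")]:
#     apply_override(FLAGS.config, key, getattr(FLAGS, flag))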

@@ -68,6 +125,86 @@ def main(_):

jax.config.update("jax_debug_nans", FLAGS.config.debug_nans)

# Override lambda value if provided via command line
if FLAGS.lmbda is not None:
original_lmbda = FLAGS.config.lmbda
FLAGS.config.lmbda = FLAGS.lmbda
logging.info(f"Overriding lambda: {original_lmbda} -> {FLAGS.lmbda}")
else:
logging.info(f"Using lambda from config: {FLAGS.config.lmbda}")

# Override learning rate if provided via command line
if FLAGS.lr is not None:
original_lr = FLAGS.config.learning_rate
FLAGS.config.learning_rate = FLAGS.lr
logging.info(f"Overriding learning rate: {original_lr} -> {FLAGS.lr}")
else:
logging.info(f"Using learning rate from config: {FLAGS.config.learning_rate}")

if FLAGS.batch_size is not None:
original_batch_size = FLAGS.config.batch_size
FLAGS.config.batch_size = FLAGS.batch_size
logging.info(f"Overriding batch size: {original_batch_size} -> {FLAGS.batch_size}")
else:
logging.info(f"Using batch size from config: {FLAGS.config.batch_size}")

if FLAGS.temperature is not None:
original_temperature = FLAGS.config.temperature
FLAGS.config.temperature = FLAGS.temperature
logging.info(f"Overriding temperature: {original_temperature} -> {FLAGS.temperature}")
else:
logging.info(f"Using temperature from config: {FLAGS.config.temperature}")

if FLAGS.y_channels is not None:
model_cls = FLAGS.config.model_cls
original_y_channels = FLAGS.config.model_kwargs[model_cls]["y_channels"]
FLAGS.config.model_kwargs[model_cls]["y_channels"] = FLAGS.y_channels
logging.info(f"Overriding y_channels for {model_cls}: {original_y_channels} -> {FLAGS.y_channels}")
else:
model_cls = FLAGS.config.model_cls
logging.info(f"Using y_channels from config for {model_cls}: {FLAGS.config.model_kwargs[model_cls]['y_channels']}")

# Override dynamic_t if provided via command line
if FLAGS.dynamic_t is not None:
original_dynamic_t = FLAGS.config.dynamic_t
FLAGS.config.dynamic_t = FLAGS.dynamic_t
logging.info(f"Overriding dynamic_t: {original_dynamic_t} -> {FLAGS.dynamic_t}")
else:
logging.info(f"Using dynamic_t from config: {FLAGS.config.dynamic_t}")

# Override wandb_project if provided via command line
if FLAGS.wandb_project is not None:
original_wandb_project = FLAGS.config.wandb_project
FLAGS.config.wandb_project = FLAGS.wandb_project
logging.info(f"Overriding wandb_project: {original_wandb_project} -> {FLAGS.wandb_project}")
else:
logging.info(f"Using wandb_project from config: {FLAGS.config.wandb_project}")

# Override max_temp if provided via command line
if FLAGS.max_temp is not None:
original_max_temp = FLAGS.config.max_temp
FLAGS.config.max_temp = FLAGS.max_temp
logging.info(f"Overriding max_temp: {original_max_temp} -> {FLAGS.max_temp}")
else:
logging.info(f"Using max_temp from config: {FLAGS.config.max_temp}")

# Override min_temp if provided via command line
if FLAGS.min_temp is not None:
original_min_temp = FLAGS.config.min_temp
FLAGS.config.min_temp = FLAGS.min_temp
logging.info(f"Overriding min_temp: {original_min_temp} -> {FLAGS.min_temp}")
else:
logging.info(f"Using min_temp from config: {FLAGS.config.min_temp}")

# Override bound_epoch if provided via command line
if FLAGS.bound_epoch is not None:
original_bound_epoch = FLAGS.config.bound_epoch
FLAGS.config.bound_epoch = FLAGS.bound_epoch
logging.info(f"Overriding bound_epoch: {original_bound_epoch} -> {FLAGS.bound_epoch}")
else:
logging.info(f"Using bound_epoch from config: {FLAGS.config.bound_epoch}")


host_count = jax.process_count()
# host_id = jax.process_index()
local_device_count = jax.local_device_count()
@@ -86,6 +223,11 @@ def main(_):
)
train_iterator = train_set.as_numpy_iterator()

# Load VGG16 model for Wasserstein distortion calculation
logging.info("Loading VGG16 model for perceptual loss...")
pretrained_features.load_vgg16_model()
logging.info("VGG16 model loaded successfully.")

train_lib.train(
FLAGS.config,
FLAGS.checkpoint_path,
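For context on what the lmbda flag trades off: NTC training minimizes a rate-distortion Lagrangian, and with this change the distortion term is the VGG-16 Wasserstein distortion rather than squared error. In standard form (the exact reduction used in train_lib is not shown in this diff):

\mathcal{L} = \mathbb{E}_x\big[\, R(x) + \lambda \, D_W(x, \hat{x}) \,\big]

where R is the rate returned by the model, D_W is vgg16_wasserstein_distortion, and \hat{x} is the reconstruction x_rec.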