From 0cc429581e42be55841c48581b0000e70cad167a Mon Sep 17 00:00:00 2001 From: Ryan Mueller Date: Sat, 11 Oct 2025 11:11:42 -0400 Subject: [PATCH 01/56] Pixel shuffle rggb model code --- src/Restorer/Cond_NAF_ps.py | 333 ++++++++++++++++++++++++++++++++++++ 1 file changed, 333 insertions(+) create mode 100644 src/Restorer/Cond_NAF_ps.py diff --git a/src/Restorer/Cond_NAF_ps.py b/src/Restorer/Cond_NAF_ps.py new file mode 100644 index 0000000..b459ce6 --- /dev/null +++ b/src/Restorer/Cond_NAF_ps.py @@ -0,0 +1,333 @@ +import torch.nn.functional as F +import torch +import torch.nn as nn + +class LayerNormFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, x, weight, bias, eps): + ctx.eps = eps + N, C, H, W = x.size() + mu = x.mean(1, keepdim=True) + var = (x - mu).pow(2).mean(1, keepdim=True) + y = (x - mu) / (var + eps).sqrt() + ctx.save_for_backward(y, var, weight) + y = weight.view(1, C, 1, 1) * y + bias.view(1, C, 1, 1) + return y + + @staticmethod + def backward(ctx, grad_output): + eps = ctx.eps + + N, C, H, W = grad_output.size() + y, var, weight = ctx.saved_variables + g = grad_output * weight.view(1, C, 1, 1) + mean_g = g.mean(dim=1, keepdim=True) + + mean_gy = (g * y).mean(dim=1, keepdim=True) + gx = 1.0 / torch.sqrt(var + eps) * (g - y * mean_gy - mean_g) + return ( + gx, + (grad_output * y).sum(dim=3).sum(dim=2).sum(dim=0), + grad_output.sum(dim=3).sum(dim=2).sum(dim=0), + None, + ) + + +class LayerNorm2d(nn.Module): + def __init__(self, channels, eps=1e-6): + super(LayerNorm2d, self).__init__() + self.register_parameter("weight", nn.Parameter(torch.ones(channels))) + self.register_parameter("bias", nn.Parameter(torch.zeros(channels))) + self.eps = eps + + def forward(self, x): + return LayerNormFunction.apply(x, self.weight, self.bias, self.eps) + + +class SimpleGate(nn.Module): + def forward(self, x): + x1, x2 = x.chunk(2, dim=1) + return x1 * x2 + + +class ConditionedChannelAttention(nn.Module): + def __init__(self, dims, cat_dims): + 
super().__init__() + in_dim = dims + cat_dims + # self.mlp = nn.Sequential( + # nn.Linear(in_dim, int(in_dim*1.5)), + # nn.GELU(), + # nn.Dropout(0.2), + # nn.Linear(int(in_dim*1.5), dims) + # ) + self.mlp = nn.Sequential(nn.Linear(in_dim, dims)) + self.pool = nn.AdaptiveAvgPool2d(1) + + def forward(self, x, conditioning): + pool = self.pool(x) + conditioning = conditioning.unsqueeze(-1).unsqueeze(-1) + cat_channels = torch.cat([pool, conditioning], dim=1) + cat_channels = cat_channels.permute(0, 2, 3, 1) + ca = self.mlp(cat_channels).permute(0, 3, 1, 2) + + return ca + + +class NAFBlock0(nn.Module): + def __init__(self, c, DW_Expand=2, FFN_Expand=2, drop_out_rate=0.0, cond_chans=0): + super().__init__() + dw_channel = c * DW_Expand + self.conv1 = nn.Conv2d( + in_channels=c, + out_channels=dw_channel, + kernel_size=1, + padding=0, + stride=1, + groups=1, + bias=True, + ) + self.conv2 = nn.Conv2d( + in_channels=dw_channel, + out_channels=dw_channel, + kernel_size=3, + padding=1, + stride=1, + groups=dw_channel, + bias=True, + ) + self.conv3 = nn.Conv2d( + in_channels=dw_channel // 2, + out_channels=c, + kernel_size=1, + padding=0, + stride=1, + groups=1, + bias=True, + ) + + # Simplified Channel Attention + self.sca = ConditionedChannelAttention(dw_channel // 2, cond_chans) + + # SimpleGate + self.sg = SimpleGate() + + ffn_channel = FFN_Expand * c + self.conv4 = nn.Conv2d( + in_channels=c, + out_channels=ffn_channel, + kernel_size=1, + padding=0, + stride=1, + groups=1, + bias=True, + ) + self.conv5 = nn.Conv2d( + in_channels=ffn_channel // 2, + out_channels=c, + kernel_size=1, + padding=0, + stride=1, + groups=1, + bias=True, + ) + + # self.grn = GRN(ffn_channel // 2) + + self.norm1 = LayerNorm2d(c) + self.norm2 = LayerNorm2d(c) + + self.dropout1 = ( + nn.Dropout(drop_out_rate) if drop_out_rate > 0.0 else nn.Identity() + ) + self.dropout2 = ( + nn.Dropout(drop_out_rate) if drop_out_rate > 0.0 else nn.Identity() + ) + + self.beta = nn.Parameter(torch.zeros((1, c, 1, 
1)), requires_grad=True) + self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True) + + def forward(self, input): + inp = input[0] + cond = input[1] + + x = inp + + x = self.norm1(x) + + x = self.conv1(x) + x = self.conv2(x) + x = self.sg(x) + x = x * self.sca(x, cond) + x = self.conv3(x) + + x = self.dropout1(x) + + y = inp + x * self.beta + + # Channel Mixing + x = self.conv4(self.norm2(y)) + x = self.sg(x) + # x = self.grn(x) + x = self.conv5(x) + + x = self.dropout2(x) + + return (y + x * self.gamma, cond) + +class Restorer(nn.Module): + def __init__( + self, + in_channels=3, + out_channels=3, + width=16, + middle_blk_num=1, + enc_blk_nums=[], + dec_blk_nums=[], + cond_input=1, + cond_output=32, + expand_dims=2, + drop_out_rate=0.0, + drop_out_rate_increment=0.0 + ): + super().__init__() + + self.expand_dims = expand_dims + self.conditioning_gen = nn.Sequential( + nn.Linear(cond_input, 64), nn.ReLU(), nn.Dropout(drop_out_rate), nn.Linear(64, cond_output), + ) + + self.intro = nn.Conv2d( + in_channels=in_channels, + out_channels=width, + kernel_size=3, + padding=1, + stride=1, + groups=1, + bias=True, + ) + self.ending = nn.Conv2d( + in_channels=width, + out_channels=out_channels, + kernel_size=3, + padding=1, + stride=1, + groups=1, + bias=True, + ) + + self.encoders = nn.ModuleList() + self.decoders = nn.ModuleList() + self.middle_blks = nn.ModuleList() + self.ups = nn.ModuleList() + self.downs = nn.ModuleList() + + chan = width + # for num in enc_blk_nums: + for i in range(len(enc_blk_nums)): + num = enc_blk_nums[i] + self.encoders.append( + nn.Sequential( + *[ + NAFBlock0(chan, cond_chans=cond_output, drop_out_rate=drop_out_rate) + for _ in range(num) + ] + ) + ) + drop_out_rate += drop_out_rate_increment + self.downs.append(nn.Conv2d(chan, 2 * chan, 2, 2)) + chan = chan * 2 + + self.middle_blks = nn.Sequential( + *[ + NAFBlock0(chan, cond_chans=cond_output, drop_out_rate=drop_out_rate) + for _ in range(middle_blk_num) + ] + ) + + for i in 
range(len(dec_blk_nums)): + num = dec_blk_nums[i] + self.ups.append( + nn.Sequential( + nn.Conv2d(chan, chan * 2, 1, bias=False), nn.PixelShuffle(2) + ) + ) + drop_out_rate -= drop_out_rate_increment + chan = chan // 2 + self.decoders.append( + nn.Sequential( + *[ + NAFBlock0(chan, cond_chans=cond_output, drop_out_rate=drop_out_rate) + for _ in range(num) + ] + ) + ) + + + self.padder_size = 2 ** len(self.encoders) + + def forward(self, inp, cond_in): + # Conditioning: + cond = self.conditioning_gen(cond_in) + + B, C, H, W = inp.shape + inp = self.check_image_size(inp) + + x = self.intro(inp) + + encs = [] + for encoder, down in zip(self.encoders, self.downs): + x = encoder((x, cond))[0] + encs.append(x) + x = down(x) + + x = self.middle_blks((x, cond))[0] + + for decoder, up, enc_skip in zip(self.decoders, self.ups, encs[::-1]): + x = up(x) + x = x + enc_skip + x = decoder((x, cond))[0] + + x = self.ending(x) + + return x[:, :, :H, :W] + + def check_image_size(self, x): + _, _, h, w = x.size() + mod_pad_h = (self.padder_size - h % self.padder_size) % self.padder_size + mod_pad_w = (self.padder_size - w % self.padder_size) % self.padder_size + x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h)) + return x + + + +class AddPixelShuffle(nn.Module): + def __init__(self, model, in_channels=4, out_channels=3): + super().__init__() + self.model = model + self.ps = nn.PixelShuffle(2) + + def forward(self, x, iso, residual): + x = self.model(x, iso) + x = self.ps(x) + return x + residual + + + + +# Get the model +def load_model(weight_file_path): + + model = Restorer( + width=64, + enc_blk_nums = [2, 2, 4, 8], + middle_blk_num = 12, + dec_blk_nums = [2, 2, 2, 2], + cond_input = 2, + in_channels = 4, + out_channels = 12, + ) + model = AddPixelShuffle(model) + + state_dict = torch.load(weight_file_path, map_location=torch.device('cpu')) + model.load_state_dict(state_dict) + return model From c632d315793145ef5c5e0703ab6b9122721b0675 Mon Sep 17 00:00:00 2001 From: Ryan Mueller Date: 
Sun, 12 Oct 2025 01:15:26 -0400 Subject: [PATCH 02/56] SmallRawDataset for pretraining --- src/training/SmallRawDataset.py | 88 +++++++++++++++++++++++++++++++++ 1 file changed, 88 insertions(+) create mode 100644 src/training/SmallRawDataset.py diff --git a/src/training/SmallRawDataset.py b/src/training/SmallRawDataset.py new file mode 100644 index 0000000..0bf1eb2 --- /dev/null +++ b/src/training/SmallRawDataset.py @@ -0,0 +1,88 @@ +import pandas as pd +import os +from torch.utils.data import Dataset +import imageio +from colour_demosaicing import ( + ROOT_RESOURCES_EXAMPLES, + demosaicing_CFA_Bayer_bilinear, + demosaicing_CFA_Bayer_Malvar2004, + demosaicing_CFA_Bayer_Menon2007, + mosaicing_CFA_Bayer) + +from src.training.utils import inverse_gamma_tone_curve, cfa_to_sparse +import numpy as np +import torch +import cv2 +from src.training.align_images import apply_alignment + + +class SmallRawDataset(Dataset): + def __init__(self, path, csv, crop_size=180, buffer=10, validation=False): + super().__init__() + self.df = pd.read_csv(csv) + self.path = path + self.crop_size = crop_size + self.buffer = buffer + self.coordinate_iso = 6400 + self.validation=validation + + def __len__(self): + return len(self.df) + + def __getitem__(self, idx): + row = self.df.iloc[idx] + # Load images + with imageio.imopen(f"{row.bayer_path}", "r") as image_resource: + bayer_data = image_resource.read() + + with imageio.imopen(f"{row.gt_path}", "r") as image_resource: + gt_image = image_resource.read() + gt_image = gt_image/255 + bayer_data = bayer_data/255 + + aligned = apply_alignment(gt_image, row.to_dict()) + demosaiced_noisy = demosaicing_CFA_Bayer_Malvar2004(bayer_data) + + h, w, _ = gt_image.shape + + #Crop images + if not self.validation: + top = np.random.randint(0 + self.buffer, h - self.crop_size - self.buffer) + left = np.random.randint(0 + self.buffer, w - self.crop_size - self.buffer) + else: + top = (h - self.crop_size) // 2 + left = (w - self.crop_size) // 2 + + if top % 
2 != 0: top = top - 1 + if left % 2 != 0: left = left - 1 + bottom = top + self.crop_size + right = left + self.crop_size + aligned = aligned[top:bottom, left:right] + gt_image = gt_image[top:bottom, left:right] + bayer_data = bayer_data[top:bottom, left:right] + h, w, _ = gt_image.shape + + # Translate to linear + gt_image = inverse_gamma_tone_curve(gt_image) + aligned = inverse_gamma_tone_curve(aligned) + bayer_data = inverse_gamma_tone_curve(bayer_data) + + demosaiced_noisy = demosaicing_CFA_Bayer_Malvar2004(bayer_data) + + aligned = aligned * demosaiced_noisy.mean() / aligned.mean() + gt_image = gt_image * demosaiced_noisy.mean() / gt_image.mean() + + sparse, _ = cfa_to_sparse(bayer_data) + rggb = bayer_data.reshape(h // 2, 2, w // 2, 2, 1).transpose(3, 1, 4, 0, 2).reshape(4, h // 2, w // 2) + + # Convert to tensors + output = { + "bayer": torch.tensor(bayer_data).to(float).clip(0,1), + "gt": torch.tensor(gt_image).to(float).permute(2, 0, 1).clip(0,1), + "aligned": torch.tensor(aligned).to(float).permute(2, 0, 1).clip(0,1), + "sparse": torch.tensor(sparse).to(float).clip(0,1), + "noisy": torch.tensor(demosaiced_noisy).to(float).permute(2, 0, 1).clip(0,1), + "rggb": torch.tensor(rggb).to(float).clip(0,1), + "conditioning": torch.tensor([row.iso/self.coordinate_iso]).to(float), + } + return output \ No newline at end of file From 67736a48c5478a2c55930640e42058946c8c4a16 Mon Sep 17 00:00:00 2001 From: Ryan Mueller Date: Tue, 14 Oct 2025 13:26:47 -0400 Subject: [PATCH 03/56] First version of pretraining loop --- 1_pretrain_model.ipynb | 217 ++++++++++++ src/Restorer/Cond_CHASPA.py | 456 +++++++++++++++++++++++++ src/Restorer/Cond_NAF.py | 493 ++++++++++++++++++++++++++++ src/Restorer/Cond_NAF_original.py | 393 ++++++++++++++++++++++ src/Restorer/Cond_NAF_ps.py | 59 ++-- src/Restorer/Restorer.py | 102 ++++++ src/training/RawDataset.py | 155 +++++++++ src/training/ShadowAwareLoss.py | 88 +++++ src/training/SmallRawDataset.py | 2 +- src/training/VGGFeatureExtractor.py 
| 81 +++++ src/training/align_images.py | 181 ++++++++++ src/training/image_utils.py | 289 ++++++++++++++++ src/training/rggb_loop.py | 120 +++++++ src/training/sparse_loop.py | 136 ++++++++ src/training/utils.py | 295 +++++++++++++++++ 15 files changed, 3031 insertions(+), 36 deletions(-) create mode 100644 1_pretrain_model.ipynb create mode 100644 src/Restorer/Cond_CHASPA.py create mode 100644 src/Restorer/Cond_NAF.py create mode 100644 src/Restorer/Cond_NAF_original.py create mode 100644 src/training/RawDataset.py create mode 100644 src/training/ShadowAwareLoss.py create mode 100644 src/training/VGGFeatureExtractor.py create mode 100644 src/training/align_images.py create mode 100644 src/training/image_utils.py create mode 100644 src/training/rggb_loop.py create mode 100644 src/training/sparse_loop.py create mode 100644 src/training/utils.py diff --git a/1_pretrain_model.ipynb b/1_pretrain_model.ipynb new file mode 100644 index 0000000..1394471 --- /dev/null +++ b/1_pretrain_model.ipynb @@ -0,0 +1,217 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "f6351e77", + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "import matplotlib.pyplot as plt\n", + "from torch.utils.data import DataLoader, random_split\n", + "import torch.nn as nn\n", + "import torch\n", + "import copy" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2043dc8e", + "metadata": {}, + "outputs": [], + "source": [ + "from src.training.SmallRawDataset import SmallRawDataset\n", + "# from src.Restorer.Cond_NAF import make_model, make_full_model, make_full_model_RGGB\n", + "\n", + "from src.training.ShadowAwareLoss import ShadowAwareLoss\n", + "from src.training.VGGFeatureExtractor import VGGFeatureExtractor\n", + "from src.training.sparse_loop import train_one_epoch, visualize\n", + "from src.training.rggb_loop import train_one_epoch_rggb, visualize\n", + "\n", + "\n", + "from src.training.utils import 
apply_gamma_torch\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ba20b866", + "metadata": {}, + "outputs": [], + "source": [ + "device= 'mps'\n", + "\n", + "batch_size = 2\n", + "lr = 1e-4 * batch_size / 4\n", + "# lr = 1e-3 * batch_size / 32\n", + "clipping = 1e-2\n", + "\n", + "num_epochs = 75\n", + "val_split = 0.2" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "15f16fa7", + "metadata": {}, + "outputs": [], + "source": [ + "dataset = SmallRawDataset('/Volumes/EasyStore/RAWNIND/JPEGs/Cropped_JPEG_high_quality/', 'align.csv', crop_size=256)\n", + "\n", + "\n", + "# Split dataset into train and val\n", + "val_size = int(len(dataset) * val_split)\n", + "train_size = len(dataset) - val_size\n", + "torch.manual_seed(42) # For reproducibility\n", + "train_dataset, val_dataset = random_split(dataset, [train_size, val_size])\n", + "# Set the validation dataset to use the same crops\n", + "val_dataset = copy.deepcopy(val_dataset)\n", + "val_dataset.dataset.validation = True\n", + "\n", + "train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False, num_workers=0)\n", + "val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3a219ab5", + "metadata": {}, + "outputs": [], + "source": [ + "# model_name = '/Volumes/EasyStore/models/Cond_NAF_original_null.pt'\n", + "# model = make_full_model_RGGB(model_name=None)\n", + "# model = model.to(device)\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7669fd6d", + "metadata": {}, + "outputs": [], + "source": [ + "# # model_name = '/Volumes/EasyStore/models/Restorer_train_vgg_relu_300_full.pt'\n", + "# # model = make_full_model(model_name=model_name)\n", + "# # model = model.to(device)\n", + "\n", + "# model_name = '/Volumes/EasyStore/models/Restorer_train_vgg_relu_300_full_RGGB.pt'\n", + "# model_name = 
'/Volumes/EasyStore/models/Restorer_train_vgg_relu_0_full_RGGB.pt'\n", + "\n", + "# model = make_full_model_RGGB(model_name=model_name)\n", + "# model = model.to(device)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "086fbb3a", + "metadata": {}, + "outputs": [], + "source": [ + "from src.Restorer.Cond_NAF import make_model, make_full_model, make_full_model_RGGB\n", + "\n", + "model_name = '/Volumes/EasyStore/models/Restorer_train_vgg_relu_300_full_RGGB.pt'\n", + "model_name = '/Volumes/EasyStore/models/Restorer_train_vgg_relu_0_full_RGGB.pt'\n", + "\n", + "model = make_full_model_RGGB(model_name=model_name)\n", + "model = model.to(device)\n", + "\n", + "\n", + "from src.Restorer.Cond_CHASPA import make_full_model_RGGB\n", + "\n", + "\n", + "model_name = '/Volumes/EasyStore/models/Restorer_train_vgg_relu_0_full_RGGB_CHASPA.pt'\n", + "\n", + "model = make_full_model_RGGB(model_name=model_name)\n", + "model = model.to(device)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8012a124", + "metadata": {}, + "outputs": [], + "source": [ + "# model_name = '/Volumes/EasyStore/models/Restorer_train_vgg_relu_300_full.pt'\n", + "# model = make_full_model(model_name=model_name)\n", + "# # model = model.to(device)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1a666123", + "metadata": {}, + "outputs": [], + "source": [ + "model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6af0f3a2", + "metadata": {}, + "outputs": [], + "source": [ + "optimizer = torch.optim.Adam(model.parameters(), lr=lr)\n", + "\n", + "vfe = VGGFeatureExtractor(config=((1, 64), (1, 128), (1, 256), (1, 512), (1, 512),), \n", + " feature_layers=[14], \n", + " activation=nn.ReLU\n", + " )\n", + "vfe = vfe.to(device)\n", + "\n", + "loss_fn = ShadowAwareLoss(\n", + " alpha=0.2,\n", + " beta=5.0,\n", + " l1_weight=0.16,\n", + " ssim_weight=0.84,\n", + " tv_weight=0.0,\n", + " vgg_loss_weight=0,\n", + " 
apply_gamma_fn=apply_gamma_torch,\n", + " vgg_feature_extractor=vfe,\n", + " device=device,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8bc86a06", + "metadata": {}, + "outputs": [], + "source": [ + "for epoch in range(num_epochs):\n", + " train_one_epoch(epoch, model, optimizer, model_name, train_loader, device, loss_fn, clipping, log_interval = 10, sleep=0.0, rggb=True)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "OnSight", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.11" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/src/Restorer/Cond_CHASPA.py b/src/Restorer/Cond_CHASPA.py new file mode 100644 index 0000000..9371a48 --- /dev/null +++ b/src/Restorer/Cond_CHASPA.py @@ -0,0 +1,456 @@ +import torch.nn.functional as F +import torch +import torch.nn as nn + +class SimpleGate(nn.Module): + def forward(self, x): + x1, x2 = x.chunk(2, dim=1) + return x1 * x2 + + +class ConditionedChannelAttention(nn.Module): + def __init__(self, dims, cat_dims): + super().__init__() + in_dim = dims + cat_dims + # self.mlp = nn.Sequential( + # nn.Linear(in_dim, int(in_dim*1.5)), + # nn.GELU(), + # nn.Dropout(0.2), + # nn.Linear(int(in_dim*1.5), dims) + # ) + self.mlp = nn.Sequential(nn.Linear(in_dim, dims)) + self.pool = nn.AdaptiveAvgPool2d(1) + + def forward(self, x, conditioning): + pool = self.pool(x) + conditioning = conditioning.unsqueeze(-1).unsqueeze(-1) + cat_channels = torch.cat([pool, conditioning], dim=1) + cat_channels = cat_channels.permute(0, 2, 3, 1) + ca = self.mlp(cat_channels).permute(0, 3, 1, 2) + + return ca + +class NKA(nn.Module): + def __init__(self, dim, channel_reduction = 8): + super().__init__() + + reduced_channels = dim // 
channel_reduction + self.proj_1 = nn.Conv2d(dim, reduced_channels, 1, 1, 0) + self.dwconv = nn.Conv2d(reduced_channels, reduced_channels, 3, 1, 1, groups=reduced_channels) + self.proj_2 = nn.Conv2d(reduced_channels, reduced_channels * 2, 1, 1, 0) + self.sg = SimpleGate() + self.attention = nn.Conv2d(reduced_channels, dim, 1, 1, 0) + + def forward(self, x): + B, C, H, W = x.shape + # First projection to a smaller dimension + y = self.proj_1(x) + # DW conv + attn = self.dwconv(y) + # PW to increase channel count for SG + attn = self.proj_2(attn) + # Non-linearity + attn = self.sg(attn) + # Back to original dimensions + out = x * self.attention(attn) + return out + +class CHASPABlock(nn.Module): + def __init__(self, c, DW_Expand=2, FFN_Expand=2, drop_out_rate=0.0, cond_chans=0): + super().__init__() + dw_channel = c * DW_Expand + + self.NKA = NKA(c) + self.conv1 = nn.Conv2d( + in_channels=c, + out_channels=c, + kernel_size=3, + padding=1, + stride=1, + groups=c, + bias=True, + ) + + # Simplified Channel Attention + self.sca = ConditionedChannelAttention(dw_channel // 2, cond_chans) + + # SimpleGate + self.sg = SimpleGate() + + ffn_channel = FFN_Expand * c + self.conv2 = nn.Conv2d( + in_channels=c, + out_channels=ffn_channel, + kernel_size=1, + padding=0, + stride=1, + groups=1, + bias=True, + ) + self.conv3 = nn.Conv2d( + in_channels=ffn_channel // 2, + out_channels=c, + kernel_size=1, + padding=0, + stride=1, + groups=1, + bias=True, + ) + + # self.grn = GRN(ffn_channel // 2) + + self.norm1 = nn.GroupNorm(1, c) + self.norm2 = nn.GroupNorm(1, c) + + self.dropout1 = ( + nn.Dropout(drop_out_rate) if drop_out_rate > 0.0 else nn.Identity() + ) + self.dropout2 = ( + nn.Dropout(drop_out_rate) if drop_out_rate > 0.0 else nn.Identity() + ) + + self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True) + self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True) + + def forward(self, input): + inp = input[0] + cond = input[1] + + x = inp + x = 
self.norm1(x) + + # Channel Mixing + x = self.conv2(x) + x = self.sg(x) + x = x * self.sca(x, cond) + x = self.conv3(x) + x = self.dropout2(x) + y = inp + x * self.beta + + #Spatial Mixing + x = self.NKA(self.norm2(y)) + x = self.conv1(x) + x = self.dropout1(x) + + + return (y + x * self.gamma, cond) + + + +class CondSEBlock(nn.Module): + def __init__(self, chan, reduction=16, cond_chan=1): + super().__init__() + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Sequential( + nn.Linear(chan + cond_chan, chan // reduction, bias=False), + nn.ReLU(inplace=True), + nn.Linear(chan // reduction, chan, bias=False), + nn.Sigmoid() + ) + + def forward(self, x, conditioning): + b, c, _, _ = x.size() + y = self.avg_pool(x).view(b, c) + y = torch.cat([y, conditioning], dim=1) + y = self.fc(y).view(b, c, 1, 1) + return x * y.expand_as(x) + +class CondFuser(nn.Module): + def __init__(self, chan, cond_chan=1): + super().__init__() + self.cca = ConditionedChannelAttention(chan * 2, cond_chan) + + def forward(self, x1, x2, cond): + x = torch.cat([x1, x2], dim=1) + x = self.cca(x, cond) * x + x1, x2 = x.chunk(2, dim=1) + return x1 + x2 + + +class Restorer(nn.Module): + def __init__( + self, + in_channels=3, + out_channels=3, + middle_blk_num=1, + enc_blk_nums=[], + dec_blk_nums=[], + chans = [], + cond_input=1, + cond_output=32, + expand_dims=2, + drop_out_rate=0.0, + drop_out_rate_increment=0.0, + rggb = False + ): + super().__init__() + width = chans[0] + + self.expand_dims = expand_dims + self.conditioning_gen = nn.Sequential( + nn.Linear(cond_input, 64), nn.ReLU(), nn.Dropout(drop_out_rate), nn.Linear(64, cond_output), + ) + self.rggb = rggb + if not rggb: + self.intro = nn.Conv2d( + in_channels=in_channels, + out_channels=width, + kernel_size=3, + padding=1, + stride=1, + groups=1, + bias=True, + ) + else: + self.intro = nn.Sequential( + + nn.Conv2d( + in_channels=in_channels, + out_channels=width * 2 ** 2, + kernel_size=3, + padding=1, + stride=1, + groups=1, + bias=True, 
+ ), + nn.PixelShuffle(2) + ) + + nn.Conv2d( + in_channels=in_channels, + out_channels=width, + kernel_size=3, + padding=1, + stride=1, + groups=1, + bias=True, + ) + self.ending = nn.Conv2d( + in_channels=width, + out_channels=out_channels, + kernel_size=3, + padding=1, + stride=1, + groups=1, + bias=True, + ) + + self.encoders = nn.ModuleList() + self.decoders = nn.ModuleList() + self.middle_blks = nn.ModuleList() + self.ups = nn.ModuleList() + self.downs = nn.ModuleList() + self.merges = nn.ModuleList() + + + # for num in enc_blk_nums: + for i in range(len(enc_blk_nums)): + current_chan = chans[i] + next_chan = chans[i + 1] + num = enc_blk_nums[i] + self.encoders.append( + nn.Sequential( + *[ + CHASPABlock(current_chan, cond_chans=cond_output, drop_out_rate=drop_out_rate) + for _ in range(num) + ] + ) + ) + drop_out_rate += drop_out_rate_increment + self.downs.append(nn.Conv2d(current_chan, next_chan, 2, 2)) + + self.middle_blks = nn.Sequential( + *[ + CHASPABlock(next_chan, cond_chans=cond_output, drop_out_rate=drop_out_rate) + for _ in range(middle_blk_num) + ] + ) + + for i in range(len(dec_blk_nums)): + current_chan = chans[-i-1] + next_chan = chans[-i-2] + num = dec_blk_nums[i] + self.ups.append( + nn.Sequential( + nn.Conv2d(current_chan, next_chan * 2 ** 2, 1, bias=False), nn.PixelShuffle(2) + ) + ) + drop_out_rate -= drop_out_rate_increment + self.merges.append(CondFuser(next_chan, cond_chan=cond_output)) + self.decoders.append( + nn.Sequential( + *[ + CHASPABlock(next_chan, cond_chans=cond_output, drop_out_rate=drop_out_rate) + for _ in range(num) + ] + ) + ) + + + self.padder_size = 2 ** len(self.encoders) + + def forward(self, inp, cond_in): + # Conditioning: + cond = self.conditioning_gen(cond_in) + + B, C, H, W = inp.shape + if self.rggb: + H = 2 * H + W = 2 * W + inp = self.check_image_size(inp) + + x = self.intro(inp) + + encs = [] + for encoder, down in zip(self.encoders, self.downs): + x = encoder((x, cond))[0] + encs.append(x) + x = down(x) + + 
x = self.middle_blks((x, cond))[0] + + for decoder, up, merge, enc_skip in zip(self.decoders, self.ups, self.merges, encs[::-1]): + x = up(x) + x = merge(x, enc_skip, cond) + x = decoder((x, cond))[0] + + x = self.ending(x) + return x[:, :, :H, :W] + + def check_image_size(self, x): + _, _, h, w = x.size() + mod_pad_h = (self.padder_size - h % self.padder_size) % self.padder_size + mod_pad_w = (self.padder_size - w % self.padder_size) % self.padder_size + x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h)) + return x + + +class ModelWrapper(nn.Module): + def __init__(self): + super().__init__() + self.model = Restorer( + chans = [32, 128, 256, 512, 1024], + enc_blk_nums = [1,1,3,4], + middle_blk_num = 6, + dec_blk_nums = [2, 2, 1, 1], + cond_input = 1, + in_channels = 3, + out_channels = 3, + ) + + def forward(self, x, cond, residual): + output = self.model(x, cond) + return residual + output + +def make_model(model_name = '/Volumes/EasyStore/models/Cond_NAF_variable_layers_cca_merge_unet_sparse_ssim_real_raw.pt'): + model = ModelWrapper() + if not model_name is None: + state_dict = torch.load(model_name, map_location="cpu") + model.load_state_dict(state_dict) + return model + +class ModelWrapperFull(nn.Module): + def __init__(self): + super().__init__() + self.model = Restorer( + chans = [64, 128, 256, 512, 1024], + enc_blk_nums = [1,1,3,4], + middle_blk_num = 6, + dec_blk_nums = [2, 2, 1, 1], + cond_input = 1, + in_channels = 3, + out_channels = 3, + ) + + def forward(self, x, cond, residual): + output = self.model(x, cond) + return residual + output + + +def make_full_model(model_name = '/Volumes/EasyStore/models/Cond_NAF_variable_layers_cca_merge_unet_sparse_ssim_real_raw_full.pt'): + model = ModelWrapperFull() + if not model_name is None: + state_dict = torch.load(model_name, map_location="cpu") + model.load_state_dict(state_dict) + return model + + +class ModelWrapperFullRGGB(nn.Module): + def __init__(self): + super().__init__() + self.model = Restorer( + chans = 
[32, 64, 128, 256, 256, 256], + enc_blk_nums = [2,2,2,3,4], + middle_blk_num = 12, + dec_blk_nums = [2, 2, 2, 2, 2], + cond_input = 1, + in_channels = 4, + out_channels = 3, + rggb=True, + ) + + def forward(self, x, cond, residual): + output = self.model(x, cond) + return residual + output + + +def make_full_model_RGGB(model_name = '/Volumes/EasyStore/models/Cond_NAF_variable_layers_cca_merge_unet_sparse_ssim_real_raw_full_RGGB.pt'): + model = ModelWrapperFullRGGB() + if not model_name is None: + state_dict = torch.load(model_name, map_location="cpu") + model.load_state_dict(state_dict) + return model + + +class ModelWrapperResidual(nn.Module): + def __init__(self): + super().__init__() + self.model = Restorer( + chans = [32, 64, 128, 256, 256, 256], + enc_blk_nums = [2,2,2,3,4], + middle_blk_num = 12, + dec_blk_nums = [2, 2, 2, 2, 2], + cond_input = 1, + in_channels = 6, + out_channels = 3, + rggb=False, + ) + + def forward(self, x, cond, residual): + output = self.model(x, cond) + return residual + output + +def make_residual_model(model_name = None): + model = ModelWrapperResidual() + if not model_name is None: + state_dict = torch.load(model_name, map_location="cpu") + model.load_state_dict(state_dict) + return model + + + + +class ModelWrapperDeep(nn.Module): + def __init__(self): + super().__init__() + self.model = Restorer( + chans = [32, 64, 128, 256, 512, 1024], + enc_blk_nums = [2,2,2,3,4], + middle_blk_num = 6, + dec_blk_nums = [2, 2, 2, 2, 2], + cond_input = 1, + in_channels = 3, + out_channels = 3, + ) + + def forward(self, x, cond, residual): + output = self.model(x, cond) + return residual + output + + +def make_deep_model(model_name = '/Volumes/EasyStore/models/Cond_NAF_variable_layers_cca_merge_unet_sparse_ssim_real_raw_deep.pt'): + model = ModelWrapperDeep() + if not model_name is None: + state_dict = torch.load(model_name, map_location="cpu") + model.load_state_dict(state_dict) + return model \ No newline at end of file diff --git 
a/src/Restorer/Cond_NAF.py b/src/Restorer/Cond_NAF.py new file mode 100644 index 0000000..1372378 --- /dev/null +++ b/src/Restorer/Cond_NAF.py @@ -0,0 +1,493 @@ +import torch.nn.functional as F +import torch +import torch.nn as nn + +class LayerNormFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, x, weight, bias, eps): + ctx.eps = eps + N, C, H, W = x.size() + mu = x.mean(1, keepdim=True) + var = (x - mu).pow(2).mean(1, keepdim=True) + y = (x - mu) / (var + eps).sqrt() + ctx.save_for_backward(y, var, weight) + y = weight.view(1, C, 1, 1) * y + bias.view(1, C, 1, 1) + return y + + @staticmethod + def backward(ctx, grad_output): + eps = ctx.eps + + N, C, H, W = grad_output.size() + y, var, weight = ctx.saved_variables + g = grad_output * weight.view(1, C, 1, 1) + mean_g = g.mean(dim=1, keepdim=True) + + mean_gy = (g * y).mean(dim=1, keepdim=True) + gx = 1.0 / torch.sqrt(var + eps) * (g - y * mean_gy - mean_g) + return ( + gx, + (grad_output * y).sum(dim=3).sum(dim=2).sum(dim=0), + grad_output.sum(dim=3).sum(dim=2).sum(dim=0), + None, + ) + + +class LayerNorm2d(nn.Module): + def __init__(self, channels, eps=1e-6): + super(LayerNorm2d, self).__init__() + self.register_parameter("weight", nn.Parameter(torch.ones(channels))) + self.register_parameter("bias", nn.Parameter(torch.zeros(channels))) + self.eps = eps + + def forward(self, x): + return LayerNormFunction.apply(x, self.weight, self.bias, self.eps) + + +class SimpleGate(nn.Module): + def forward(self, x): + x1, x2 = x.chunk(2, dim=1) + return x1 * x2 + + +class ConditionedChannelAttention(nn.Module): + def __init__(self, dims, cat_dims): + super().__init__() + in_dim = dims + cat_dims + # self.mlp = nn.Sequential( + # nn.Linear(in_dim, int(in_dim*1.5)), + # nn.GELU(), + # nn.Dropout(0.2), + # nn.Linear(int(in_dim*1.5), dims) + # ) + self.mlp = nn.Sequential(nn.Linear(in_dim, dims)) + self.pool = nn.AdaptiveAvgPool2d(1) + + def forward(self, x, conditioning): + pool = self.pool(x) + 
conditioning = conditioning.unsqueeze(-1).unsqueeze(-1) + cat_channels = torch.cat([pool, conditioning], dim=1) + cat_channels = cat_channels.permute(0, 2, 3, 1) + ca = self.mlp(cat_channels).permute(0, 3, 1, 2) + + return ca + + +class NAFBlock0(nn.Module): + def __init__(self, c, DW_Expand=2, FFN_Expand=2, drop_out_rate=0.0, cond_chans=0): + super().__init__() + dw_channel = c * DW_Expand + self.conv1 = nn.Conv2d( + in_channels=c, + out_channels=dw_channel, + kernel_size=1, + padding=0, + stride=1, + groups=1, + bias=True, + ) + self.conv2 = nn.Conv2d( + in_channels=dw_channel, + out_channels=dw_channel, + kernel_size=3, + padding=1, + stride=1, + groups=dw_channel, + bias=True, + ) + self.conv3 = nn.Conv2d( + in_channels=dw_channel // 2, + out_channels=c, + kernel_size=1, + padding=0, + stride=1, + groups=1, + bias=True, + ) + + # Simplified Channel Attention + self.sca = ConditionedChannelAttention(dw_channel // 2, cond_chans) + + # SimpleGate + self.sg = SimpleGate() + + ffn_channel = FFN_Expand * c + self.conv4 = nn.Conv2d( + in_channels=c, + out_channels=ffn_channel, + kernel_size=1, + padding=0, + stride=1, + groups=1, + bias=True, + ) + self.conv5 = nn.Conv2d( + in_channels=ffn_channel // 2, + out_channels=c, + kernel_size=1, + padding=0, + stride=1, + groups=1, + bias=True, + ) + + # self.grn = GRN(ffn_channel // 2) + + self.norm1 = LayerNorm2d(c) + self.norm2 = LayerNorm2d(c) + + self.dropout1 = ( + nn.Dropout(drop_out_rate) if drop_out_rate > 0.0 else nn.Identity() + ) + self.dropout2 = ( + nn.Dropout(drop_out_rate) if drop_out_rate > 0.0 else nn.Identity() + ) + + self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True) + self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True) + + def forward(self, input): + inp = input[0] + cond = input[1] + + x = inp + + x = self.norm1(x) + x = self.conv1(x) + x = self.conv2(x) + x = self.sg(x) + x = x * self.sca(x, cond) + x = self.conv3(x) + + x = self.dropout1(x) + + y = inp + x * 
self.beta + + # Channel Mixing + x = self.conv4(self.norm2(y)) + x = self.sg(x) + # x = self.grn(x) + x = self.conv5(x) + + x = self.dropout2(x) + + return (y + x * self.gamma, cond) + + +class CondSEBlock(nn.Module): + def __init__(self, chan, reduction=16, cond_chan=1): + super().__init__() + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Sequential( + nn.Linear(chan + cond_chan, chan // reduction, bias=False), + nn.ReLU(inplace=True), + nn.Linear(chan // reduction, chan, bias=False), + nn.Sigmoid() + ) + + def forward(self, x, conditioning): + b, c, _, _ = x.size() + y = self.avg_pool(x).view(b, c) + y = torch.cat([y, conditioning], dim=1) + y = self.fc(y).view(b, c, 1, 1) + return x * y.expand_as(x) + +class CondFuser(nn.Module): + def __init__(self, chan, cond_chan=1): + super().__init__() + self.cca = ConditionedChannelAttention(chan * 2, cond_chan) + + def forward(self, x1, x2, cond): + x = torch.cat([x1, x2], dim=1) + x = self.cca(x, cond) * x + x1, x2 = x.chunk(2, dim=1) + return x1 + x2 + + +class Restorer(nn.Module): + def __init__( + self, + in_channels=3, + out_channels=3, + middle_blk_num=1, + enc_blk_nums=[], + dec_blk_nums=[], + chans = [], + cond_input=1, + cond_output=32, + expand_dims=2, + drop_out_rate=0.0, + drop_out_rate_increment=0.0, + rggb = False + ): + super().__init__() + width = chans[0] + + self.expand_dims = expand_dims + self.conditioning_gen = nn.Sequential( + nn.Linear(cond_input, 64), nn.ReLU(), nn.Dropout(drop_out_rate), nn.Linear(64, cond_output), + ) + self.rggb = rggb + if not rggb: + self.intro = nn.Conv2d( + in_channels=in_channels, + out_channels=width, + kernel_size=3, + padding=1, + stride=1, + groups=1, + bias=True, + ) + else: + self.intro = nn.Sequential( + + nn.Conv2d( + in_channels=in_channels, + out_channels=width * 2 ** 2, + kernel_size=3, + padding=1, + stride=1, + groups=1, + bias=True, + ), + nn.PixelShuffle(2) + ) + + nn.Conv2d( + in_channels=in_channels, + out_channels=width, + kernel_size=3, + 
padding=1, + stride=1, + groups=1, + bias=True, + ) + self.ending = nn.Conv2d( + in_channels=width, + out_channels=out_channels, + kernel_size=3, + padding=1, + stride=1, + groups=1, + bias=True, + ) + + self.encoders = nn.ModuleList() + self.decoders = nn.ModuleList() + self.middle_blks = nn.ModuleList() + self.ups = nn.ModuleList() + self.downs = nn.ModuleList() + self.merges = nn.ModuleList() + + + # for num in enc_blk_nums: + for i in range(len(enc_blk_nums)): + current_chan = chans[i] + next_chan = chans[i + 1] + num = enc_blk_nums[i] + self.encoders.append( + nn.Sequential( + *[ + NAFBlock0(current_chan, cond_chans=cond_output, drop_out_rate=drop_out_rate) + for _ in range(num) + ] + ) + ) + drop_out_rate += drop_out_rate_increment + self.downs.append(nn.Conv2d(current_chan, next_chan, 2, 2)) + + self.middle_blks = nn.Sequential( + *[ + NAFBlock0(next_chan, cond_chans=cond_output, drop_out_rate=drop_out_rate) + for _ in range(middle_blk_num) + ] + ) + + for i in range(len(dec_blk_nums)): + current_chan = chans[-i-1] + next_chan = chans[-i-2] + num = dec_blk_nums[i] + self.ups.append( + nn.Sequential( + nn.Conv2d(current_chan, next_chan * 2 ** 2, 1, bias=False), nn.PixelShuffle(2) + ) + ) + drop_out_rate -= drop_out_rate_increment + self.merges.append(CondFuser(next_chan, cond_chan=cond_output)) + self.decoders.append( + nn.Sequential( + *[ + NAFBlock0(next_chan, cond_chans=cond_output, drop_out_rate=drop_out_rate) + for _ in range(num) + ] + ) + ) + + + self.padder_size = 2 ** len(self.encoders) + + def forward(self, inp, cond_in): + # Conditioning: + cond = self.conditioning_gen(cond_in) + + B, C, H, W = inp.shape + if self.rggb: + H = 2 * H + W = 2 * W + inp = self.check_image_size(inp) + + x = self.intro(inp) + + encs = [] + for encoder, down in zip(self.encoders, self.downs): + x = encoder((x, cond))[0] + encs.append(x) + x = down(x) + + x = self.middle_blks((x, cond))[0] + + for decoder, up, merge, enc_skip in zip(self.decoders, self.ups, self.merges, 
encs[::-1]): + x = up(x) + x = merge(x, enc_skip, cond) + x = decoder((x, cond))[0] + + x = self.ending(x) + return x[:, :, :H, :W] + + def check_image_size(self, x): + _, _, h, w = x.size() + mod_pad_h = (self.padder_size - h % self.padder_size) % self.padder_size + mod_pad_w = (self.padder_size - w % self.padder_size) % self.padder_size + x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h)) + return x + + +class ModelWrapper(nn.Module): + def __init__(self): + super().__init__() + self.model = Restorer( + chans = [32, 128, 256, 512, 1024], + enc_blk_nums = [1,1,3,4], + middle_blk_num = 6, + dec_blk_nums = [2, 2, 1, 1], + cond_input = 1, + in_channels = 3, + out_channels = 3, + ) + + def forward(self, x, cond, residual): + output = self.model(x, cond) + return residual + output + +def make_model(model_name = '/Volumes/EasyStore/models/Cond_NAF_variable_layers_cca_merge_unet_sparse_ssim_real_raw.pt'): + model = ModelWrapper() + if not model_name is None: + state_dict = torch.load(model_name, map_location="cpu") + model.load_state_dict(state_dict) + return model + +class ModelWrapperFull(nn.Module): + def __init__(self): + super().__init__() + self.model = Restorer( + chans = [64, 128, 256, 512, 1024], + enc_blk_nums = [1,1,3,4], + middle_blk_num = 6, + dec_blk_nums = [2, 2, 1, 1], + cond_input = 1, + in_channels = 3, + out_channels = 3, + ) + + def forward(self, x, cond, residual): + output = self.model(x, cond) + return residual + output + + +def make_full_model(model_name = '/Volumes/EasyStore/models/Cond_NAF_variable_layers_cca_merge_unet_sparse_ssim_real_raw_full.pt'): + model = ModelWrapperFull() + if not model_name is None: + state_dict = torch.load(model_name, map_location="cpu") + model.load_state_dict(state_dict) + return model + + +class ModelWrapperFullRGGB(nn.Module): + def __init__(self): + super().__init__() + self.model = Restorer( + chans = [32, 64, 128, 256, 256, 256], + enc_blk_nums = [2,2,2,3,4], + middle_blk_num = 12, + dec_blk_nums = [2, 2, 2, 2, 2], + 
cond_input = 1, + in_channels = 4, + out_channels = 3, + rggb=True, + ) + + def forward(self, x, cond, residual): + output = self.model(x, cond) + return residual + output + + +def make_full_model_RGGB(model_name = '/Volumes/EasyStore/models/Cond_NAF_variable_layers_cca_merge_unet_sparse_ssim_real_raw_full_RGGB.pt'): + model = ModelWrapperFullRGGB() + if not model_name is None: + state_dict = torch.load(model_name, map_location="cpu") + model.load_state_dict(state_dict) + return model + + +class ModelWrapperResidual(nn.Module): + def __init__(self): + super().__init__() + self.model = Restorer( + chans = [32, 64, 128, 256, 256, 256], + enc_blk_nums = [2,2,2,3,4], + middle_blk_num = 12, + dec_blk_nums = [2, 2, 2, 2, 2], + cond_input = 1, + in_channels = 6, + out_channels = 3, + rggb=False, + ) + + def forward(self, x, cond, residual): + output = self.model(x, cond) + return residual + output + +def make_residual_model(model_name = None): + model = ModelWrapperResidual() + if not model_name is None: + state_dict = torch.load(model_name, map_location="cpu") + model.load_state_dict(state_dict) + return model + + + + +class ModelWrapperDeep(nn.Module): + def __init__(self): + super().__init__() + self.model = Restorer( + chans = [32, 64, 128, 256, 512, 1024], + enc_blk_nums = [2,2,2,3,4], + middle_blk_num = 6, + dec_blk_nums = [2, 2, 2, 2, 2], + cond_input = 1, + in_channels = 3, + out_channels = 3, + ) + + def forward(self, x, cond, residual): + output = self.model(x, cond) + return residual + output + + +def make_deep_model(model_name = '/Volumes/EasyStore/models/Cond_NAF_variable_layers_cca_merge_unet_sparse_ssim_real_raw_deep.pt'): + model = ModelWrapperDeep() + if not model_name is None: + state_dict = torch.load(model_name, map_location="cpu") + model.load_state_dict(state_dict) + return model \ No newline at end of file diff --git a/src/Restorer/Cond_NAF_original.py b/src/Restorer/Cond_NAF_original.py new file mode 100644 index 0000000..87185d4 --- /dev/null +++ 
b/src/Restorer/Cond_NAF_original.py @@ -0,0 +1,393 @@ +import torch.nn.functional as F +import torch +import torch.nn as nn + +class LayerNormFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, x, weight, bias, eps): + ctx.eps = eps + N, C, H, W = x.size() + mu = x.mean(1, keepdim=True) + var = (x - mu).pow(2).mean(1, keepdim=True) + y = (x - mu) / (var + eps).sqrt() + ctx.save_for_backward(y, var, weight) + y = weight.view(1, C, 1, 1) * y + bias.view(1, C, 1, 1) + return y + + @staticmethod + def backward(ctx, grad_output): + eps = ctx.eps + + N, C, H, W = grad_output.size() + y, var, weight = ctx.saved_variables + g = grad_output * weight.view(1, C, 1, 1) + mean_g = g.mean(dim=1, keepdim=True) + + mean_gy = (g * y).mean(dim=1, keepdim=True) + gx = 1.0 / torch.sqrt(var + eps) * (g - y * mean_gy - mean_g) + return ( + gx, + (grad_output * y).sum(dim=3).sum(dim=2).sum(dim=0), + grad_output.sum(dim=3).sum(dim=2).sum(dim=0), + None, + ) + + +class LayerNorm2d(nn.Module): + def __init__(self, channels, eps=1e-6): + super(LayerNorm2d, self).__init__() + self.register_parameter("weight", nn.Parameter(torch.ones(channels))) + self.register_parameter("bias", nn.Parameter(torch.zeros(channels))) + self.eps = eps + + def forward(self, x): + return LayerNormFunction.apply(x, self.weight, self.bias, self.eps) + + +class SimpleGate(nn.Module): + def forward(self, x): + x1, x2 = x.chunk(2, dim=1) + return x1 * x2 + + +class ConditionedChannelAttention(nn.Module): + def __init__(self, dims, cat_dims): + super().__init__() + in_dim = dims + cat_dims + # self.mlp = nn.Sequential( + # nn.Linear(in_dim, int(in_dim*1.5)), + # nn.GELU(), + # nn.Dropout(0.2), + # nn.Linear(int(in_dim*1.5), dims) + # ) + self.mlp = nn.Sequential(nn.Linear(in_dim, dims)) + self.pool = nn.AdaptiveAvgPool2d(1) + + def forward(self, x, conditioning): + pool = self.pool(x) + conditioning = conditioning.unsqueeze(-1).unsqueeze(-1) + cat_channels = torch.cat([pool, conditioning], dim=1) + 
cat_channels = cat_channels.permute(0, 2, 3, 1) + ca = self.mlp(cat_channels).permute(0, 3, 1, 2) + + return ca + + +class NAFBlock0(nn.Module): + def __init__(self, c, DW_Expand=2, FFN_Expand=2, drop_out_rate=0.0, cond_chans=0): + super().__init__() + dw_channel = c * DW_Expand + self.conv1 = nn.Conv2d( + in_channels=c, + out_channels=dw_channel, + kernel_size=1, + padding=0, + stride=1, + groups=1, + bias=True, + ) + self.conv2 = nn.Conv2d( + in_channels=dw_channel, + out_channels=dw_channel, + kernel_size=3, + padding=1, + stride=1, + groups=dw_channel, + bias=True, + ) + self.conv3 = nn.Conv2d( + in_channels=dw_channel // 2, + out_channels=c, + kernel_size=1, + padding=0, + stride=1, + groups=1, + bias=True, + ) + + # Simplified Channel Attention + self.sca = ConditionedChannelAttention(dw_channel // 2, cond_chans) + + # SimpleGate + self.sg = SimpleGate() + + ffn_channel = FFN_Expand * c + self.conv4 = nn.Conv2d( + in_channels=c, + out_channels=ffn_channel, + kernel_size=1, + padding=0, + stride=1, + groups=1, + bias=True, + ) + self.conv5 = nn.Conv2d( + in_channels=ffn_channel // 2, + out_channels=c, + kernel_size=1, + padding=0, + stride=1, + groups=1, + bias=True, + ) + + # self.grn = GRN(ffn_channel // 2) + + self.norm1 = LayerNorm2d(c) + self.norm2 = LayerNorm2d(c) + + self.dropout1 = ( + nn.Dropout(drop_out_rate) if drop_out_rate > 0.0 else nn.Identity() + ) + self.dropout2 = ( + nn.Dropout(drop_out_rate) if drop_out_rate > 0.0 else nn.Identity() + ) + + self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True) + self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True) + + def forward(self, input): + inp = input[0] + cond = input[1] + + x = inp + + x = self.norm1(x) + x = self.conv1(x) + x = self.conv2(x) + x = self.sg(x) + x = x * self.sca(x, cond) + x = self.conv3(x) + + x = self.dropout1(x) + + y = inp + x * self.beta + + # Channel Mixing + x = self.conv4(self.norm2(y)) + x = self.sg(x) + # x = self.grn(x) + x = 
self.conv5(x) + + x = self.dropout2(x) + + return (y + x * self.gamma, cond) + + +class CondSEBlock(nn.Module): + def __init__(self, chan, reduction=16, cond_chan=1): + super().__init__() + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Sequential( + nn.Linear(chan + cond_chan, chan // reduction, bias=False), + nn.ReLU(inplace=True), + nn.Linear(chan // reduction, chan, bias=False), + nn.Sigmoid() + ) + + def forward(self, x, conditioning): + b, c, _, _ = x.size() + y = self.avg_pool(x).view(b, c) + y = torch.cat([y, conditioning], dim=1) + y = self.fc(y).view(b, c, 1, 1) + return x * y.expand_as(x) + +class CondFuser(nn.Module): + def __init__(self, chan, cond_chan=1): + super().__init__() + self.cca = ConditionedChannelAttention(chan * 2, cond_chan) + + def forward(self, x1, x2, cond): + x = torch.cat([x1, x2], dim=1) + x = self.cca(x, cond) * x + x1, x2 = x.chunk(2, dim=1) + return x1 + x2 + + +class Restorer(nn.Module): + def __init__( + self, + in_channels=3, + out_channels=3, + middle_blk_num=1, + enc_blk_nums=[], + dec_blk_nums=[], + chans = [], + cond_input=1, + cond_output=32, + expand_dims=2, + drop_out_rate=0.0, + drop_out_rate_increment=0.0, + rggb = False + ): + super().__init__() + width = chans[0] + + self.expand_dims = expand_dims + self.conditioning_gen = nn.Sequential( + nn.Linear(cond_input, 64), nn.ReLU(), nn.Dropout(drop_out_rate), nn.Linear(64, cond_output), + ) + self.rggb = rggb + if not rggb: + self.intro = nn.Conv2d( + in_channels=in_channels, + out_channels=width, + kernel_size=3, + padding=1, + stride=1, + groups=1, + bias=True, + ) + else: + self.intro = nn.Sequential( + + nn.Conv2d( + in_channels=in_channels, + out_channels=width * 2 ** 2, + kernel_size=3, + padding=1, + stride=1, + groups=1, + bias=True, + ), + nn.PixelShuffle(2) + ) + + nn.Conv2d( + in_channels=in_channels, + out_channels=width, + kernel_size=3, + padding=1, + stride=1, + groups=1, + bias=True, + ) + self.ending = nn.Conv2d( + in_channels=width, + 
out_channels=out_channels, + kernel_size=3, + padding=1, + stride=1, + groups=1, + bias=True, + ) + + self.encoders = nn.ModuleList() + self.decoders = nn.ModuleList() + self.middle_blks = nn.ModuleList() + self.ups = nn.ModuleList() + self.downs = nn.ModuleList() + # self.merges = nn.ModuleList() + + + # for num in enc_blk_nums: + for i in range(len(enc_blk_nums)): + current_chan = chans[i] + next_chan = chans[i + 1] + num = enc_blk_nums[i] + self.encoders.append( + nn.Sequential( + *[ + NAFBlock0(current_chan, cond_chans=cond_output, drop_out_rate=drop_out_rate) + for _ in range(num) + ] + ) + ) + drop_out_rate += drop_out_rate_increment + self.downs.append(nn.Conv2d(current_chan, next_chan, 2, 2)) + + self.middle_blks = nn.Sequential( + *[ + NAFBlock0(next_chan, cond_chans=cond_output, drop_out_rate=drop_out_rate) + for _ in range(middle_blk_num) + ] + ) + + for i in range(len(dec_blk_nums)): + current_chan = chans[-i-1] + next_chan = chans[-i-2] + num = dec_blk_nums[i] + self.ups.append( + nn.Sequential( + nn.Conv2d(current_chan, next_chan * 2 ** 2, 1, bias=False), nn.PixelShuffle(2) + ) + ) + drop_out_rate -= drop_out_rate_increment + # self.merges.append(CondFuser(next_chan, cond_chan=cond_output)) + self.decoders.append( + nn.Sequential( + *[ + NAFBlock0(next_chan, cond_chans=cond_output, drop_out_rate=drop_out_rate) + for _ in range(num) + ] + ) + ) + + + self.padder_size = 2 ** len(self.encoders) + + def forward(self, inp, cond_in): + # Conditioning: + cond = self.conditioning_gen(cond_in) + + B, C, H, W = inp.shape + if self.rggb: + H = 2 * H + W = 2 * W + inp = self.check_image_size(inp) + + x = self.intro(inp) + + encs = [] + for encoder, down in zip(self.encoders, self.downs): + x = encoder((x, cond))[0] + encs.append(x) + x = down(x) + + x = self.middle_blks((x, cond))[0] + + for decoder, up, enc_skip in zip(self.decoders, self.ups, encs[::-1]): + x = up(x) + x = x + enc_skip + # x = merge(x, enc_skip, cond) + x = decoder((x, cond))[0] + + x = 
self.ending(x) + return x[:, :, :H, :W] + + def check_image_size(self, x): + _, _, h, w = x.size() + mod_pad_h = (self.padder_size - h % self.padder_size) % self.padder_size + mod_pad_w = (self.padder_size - w % self.padder_size) % self.padder_size + x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h)) + return x + + +class ModelWrapperFullRGGB(nn.Module): + def __init__(self): + super().__init__() + self.model = Restorer( + chans = [64, 128, 256, 512, 1024], + enc_blk_nums = [2,2,4,8], + middle_blk_num = 12, + dec_blk_nums = [2, 2, 2, 2], + cond_input = 1, + cond_output=1, + in_channels = 4, + out_channels = 3, + rggb=True, + ) + + def forward(self, x, cond, residual): + output = self.model(x, cond) + return residual + output + + +def make_full_model_RGGB(model_name = None): + model = ModelWrapperFullRGGB() + if not model_name is None: + state_dict = torch.load(model_name, map_location="cpu") + model.load_state_dict(state_dict) + return model + +2+2+6+8+12+2+2+2+2 diff --git a/src/Restorer/Cond_NAF_ps.py b/src/Restorer/Cond_NAF_ps.py index b459ce6..f7932ec 100644 --- a/src/Restorer/Cond_NAF_ps.py +++ b/src/Restorer/Cond_NAF_ps.py @@ -2,48 +2,37 @@ import torch import torch.nn as nn -class LayerNormFunction(torch.autograd.Function): - @staticmethod - def forward(ctx, x, weight, bias, eps): - ctx.eps = eps - N, C, H, W = x.size() - mu = x.mean(1, keepdim=True) - var = (x - mu).pow(2).mean(1, keepdim=True) - y = (x - mu) / (var + eps).sqrt() - ctx.save_for_backward(y, var, weight) - y = weight.view(1, C, 1, 1) * y + bias.view(1, C, 1, 1) - return y - - @staticmethod - def backward(ctx, grad_output): - eps = ctx.eps - - N, C, H, W = grad_output.size() - y, var, weight = ctx.saved_variables - g = grad_output * weight.view(1, C, 1, 1) - mean_g = g.mean(dim=1, keepdim=True) - - mean_gy = (g * y).mean(dim=1, keepdim=True) - gx = 1.0 / torch.sqrt(var + eps) * (g - y * mean_gy - mean_g) - return ( - gx, - (grad_output * y).sum(dim=3).sum(dim=2).sum(dim=0), - 
grad_output.sum(dim=3).sum(dim=2).sum(dim=0), - None, - ) - - class LayerNorm2d(nn.Module): def __init__(self, channels, eps=1e-6): super(LayerNorm2d, self).__init__() + # 1. Keep the weight and bias as standard nn.Parameters self.register_parameter("weight", nn.Parameter(torch.ones(channels))) self.register_parameter("bias", nn.Parameter(torch.zeros(channels))) self.eps = eps + + # 2. REMOVE the self.weight_view and self.bias_view initializations from here + # They will be created dynamically in forward. def forward(self, x): - return LayerNormFunction.apply(x, self.weight, self.bias, self.eps) - + # N, C, H, W = x.size() # While useful for clarity, not strictly needed for the operations + # 1. Calculate Mean (mu) and Variance (var) across the Channel dimension (1) + # Note: We are sticking to your original normalization over C (dim=1) + mu = x.mean(1, keepdim=True) + var = (x - mu).pow(2).mean(1, keepdim=True) + + # 2. Normalize the input + y = (x - mu) / torch.sqrt(var + self.eps) + + # 3. Create the views INSIDE the forward pass, so they are part of the traced graph + weight_view = self.weight.view(1, self.weight.size(0), 1, 1) + bias_view = self.bias.view(1, self.bias.size(0), 1, 1) + + # 4. 
Apply the learnable scale (weight) and shift (bias) + y = weight_view * y + bias_view + + return y + class SimpleGate(nn.Module): def forward(self, x): x1, x2 = x.chunk(2, dim=1) @@ -306,10 +295,10 @@ def __init__(self, model, in_channels=4, out_channels=3): self.model = model self.ps = nn.PixelShuffle(2) - def forward(self, x, iso, residual): + def forward(self, x, iso): x = self.model(x, iso) x = self.ps(x) - return x + residual + return x diff --git a/src/Restorer/Restorer.py b/src/Restorer/Restorer.py index 7f39126..9ebb8e8 100644 --- a/src/Restorer/Restorer.py +++ b/src/Restorer/Restorer.py @@ -453,3 +453,105 @@ def forward(self, x, iso): x = self.model(x, iso) x = self.ps(x) return x + + + + +class NAFBlockCond(nn.Module): + def __init__(self, c, DW_Expand=2, FFN_Expand=2, drop_out_rate=0.0, cond_chans=0): + super().__init__() + dw_channel = c * DW_Expand + self.conv1 = nn.Conv2d( + in_channels=c, + out_channels=dw_channel, + kernel_size=1, + padding=0, + stride=1, + groups=1, + bias=True, + ) + self.conv2 = nn.Conv2d( + in_channels=dw_channel, + out_channels=dw_channel, + kernel_size=3, + padding=1, + stride=1, + groups=dw_channel, + bias=True, + ) + self.conv3 = nn.Conv2d( + in_channels=dw_channel // 2, + out_channels=c, + kernel_size=1, + padding=0, + stride=1, + groups=1, + bias=True, + ) + + # Simplified Channel Attention + self.sca = ConditionedChannelAttention(dw_channel // 2, cond_chans) + + # SimpleGate + self.sg = SimpleGate() + + ffn_channel = FFN_Expand * c + self.conv4 = nn.Conv2d( + in_channels=c, + out_channels=ffn_channel, + kernel_size=1, + padding=0, + stride=1, + groups=1, + bias=True, + ) + self.conv5 = nn.Conv2d( + in_channels=ffn_channel // 2, + out_channels=c, + kernel_size=1, + padding=0, + stride=1, + groups=1, + bias=True, + ) + + # self.grn = GRN(ffn_channel // 2) + + self.norm1 = LayerNorm2d(c) + self.norm2 = LayerNorm2d(c) + + self.dropout1 = ( + nn.Dropout(drop_out_rate) if drop_out_rate > 0.0 else nn.Identity() + ) + 
self.dropout2 = ( + nn.Dropout(drop_out_rate) if drop_out_rate > 0.0 else nn.Identity() + ) + + self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True) + self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True) + + def forward(self, inp, cond): + + x = inp + + x = self.norm1(x) + + x = self.conv1(x) + x = self.conv2(x) + x = self.sg(x) + x = x * self.sca(x, cond) + x = self.conv3(x) + + x = self.dropout1(x) + + y = inp + x * self.beta + + # Channel Mixing + x = self.conv4(self.norm2(y)) + x = self.sg(x) + # x = self.grn(x) + x = self.conv5(x) + + x = self.dropout2(x) + + return y + x * self.gamma, cond diff --git a/src/training/RawDataset.py b/src/training/RawDataset.py new file mode 100644 index 0000000..49d7a1b --- /dev/null +++ b/src/training/RawDataset.py @@ -0,0 +1,155 @@ +import torch +from torch.utils.data import Dataset +import pandas as pd +import numpy as np +from matplotlib.colors import rgb_to_hsv, hsv_to_rgb +import random +import cv2 +from pathlib import Path +from RawHandler.RawHandler import RawHandler +from RawHandler.utils import linear_to_srgb, pixel_unshuffle, pixel_shuffle +from colour_demosaicing import demosaicing_CFA_Bayer_Malvar2004 +from src.training.image_utils import simulate_sparse, bilinear_demosaic + +def normalized_cross_correlation(im1, im2): + im1 = im1 - np.mean(im1) + im2 = im2 - np.mean(im2) + numerator = np.sum(im1 * im2) + denominator = np.sqrt(np.sum(im1**2) * np.sum(im2**2)) + return numerator / denominator if denominator != 0 else 0.0 + +class RawDataset(Dataset): + def __init__(self, csv_path, patch_size=256, cc_threshold=0.92, colorspace='lin_rec2020'): + self.df = pd.read_csv(csv_path) + + # Drop rows with missing 'cc' or invalid types + self.df = self.df[pd.to_numeric(self.df['cc'], errors='coerce').notnull()] + self.df['cc'] = self.df['cc'].astype(float) + + # Filter based on cc threshold + original_len = len(self.df) + self.df = self.df[self.df['cc'] >= cc_threshold] + filtered_len = 
len(self.df) + + print(f"[Dataset] Filtered out {original_len - filtered_len} rows below cc threshold of {cc_threshold}.") + + self.patch_size = patch_size + self.colorspace = colorspace + + def __len__(self): + return len(self.df) + + + def __getitem__(self, idx): + row = self.df.iloc[idx] + + # Load images using RawHandler + gt = RawHandler(row["gt_image"], colorspace=self.colorspace) + + # Get dimensions + H, W = gt.raw.shape # Assume same for noisy and gt + + # Random crop coordinates + if H < self.patch_size or W < self.patch_size: + raise ValueError(f"Image is smaller than patch size: {H}x{W} < {self.patch_size}") + + align_offset = 20 + x1 = random.randint(0 + align_offset, W - (self.patch_size + align_offset)) + y1 = random.randint(0 + align_offset, H - (self.patch_size + align_offset)) + return self.get_patches(idx, x1, y1) + + + def get_patches(self, idx, x1, y1): + row = self.df.iloc[idx] + + # Load images using RawHandler + noisy = RawHandler(row["noisy_image"], colorspace=self.colorspace) + gt = RawHandler(row["gt_image"], colorspace=self.colorspace) + + # Get dimensions + H, W = gt.raw.shape # Assume same for noisy and gt + + # Random crop coordinates + if H < self.patch_size or W < self.patch_size: + raise ValueError(f"Image is smaller than patch size: {H}x{W} < {self.patch_size}") + + align_offset = 20 + x2 = x1 + self.patch_size + y2 = y1 + self.patch_size + crop_dim = (y1, y2, x1, x2) + + # Convert to float32 numpy arrays + expand_crop_dim = (y1-align_offset, y2+align_offset, x1-align_offset, x2+align_offset) + noisy_patch = noisy.as_rgb(dims=expand_crop_dim, demosaicing_func=demosaicing_CFA_Bayer_Malvar2004).astype(np.float32) + noisy_patch = noisy_patch[:, align_offset:-align_offset, align_offset:-align_offset] + + align_crop_dim = (y1-align_offset, y2+align_offset, x1-align_offset, x2+align_offset) + gt_patch = gt.as_rgb(dims=align_crop_dim, demosaicing_func=demosaicing_CFA_Bayer_Malvar2004).astype(np.float32) + + # Align gt image + gt_image = 
gt_patch.transpose(1, 2, 0) + noisy_image = noisy_patch.transpose(1, 2, 0) + + warp_matrix = np.array([[1, 0, -row.x_warp + align_offset], + [0, 1, -row.y_warp + align_offset]]) + gt_image = cv2.warpAffine(gt_image, warp_matrix, + (noisy_image.shape[1], noisy_image.shape[0]), flags=cv2.INTER_LINEAR + cv2.WARP_INVERSE_MAP) + + ncc = normalized_cross_correlation(gt_image, noisy_image) + noise_level = (gt_image-noisy_image).std(axis=(0, 1)) + + gt_patch = gt_image.transpose(2, 0, 1) + + # Adjust brightness + seperate_channels = False + if seperate_channels: + gains = gt_patch.mean(axis=(1, 2))/noisy_patch.mean(axis=(1, 2)) + + # Known issue, since the gt will clip at 1 and adjust by 2, we might have a max val of .5 + gt_patch *= 1/gains.reshape(3, 1, 1) + gt_patch = np.clip(gt_patch, 0, 1) + else: + gains = gt_patch.mean()/noisy_patch.mean() + gt_patch *= 1/gains + gt_patch = np.clip(gt_patch, 0, 1) + + # Make sparse and bilinear output + noisy_sparse = noisy.as_sparse(dims=expand_crop_dim).astype(np.float32) + noisy_sparse = noisy_sparse[:, align_offset:-align_offset, align_offset:-align_offset] + gt_sparse, g_sparse_mask = simulate_sparse(gt_patch) + bilinear = bilinear_demosaic(noisy_sparse) + noisy_sparse = torch.from_numpy(noisy_sparse).float() + noisy_tensor = torch.from_numpy(noisy_patch).float() + bilinear_tensor = torch.from_numpy(bilinear).float() + gt_tensor = torch.from_numpy(gt_patch).float() + gt_sparse_tensor = torch.from_numpy(gt_sparse).float() + + return_dict = { + 'noisy_sparse': noisy_sparse, + 'noisy_tensor': noisy_tensor, + 'bilinear_tensor': bilinear_tensor, + 'gt': gt_tensor, + 'gt_sparse': gt_sparse_tensor, + # "noisy_rggb_tensor": noisy_rggb_tensor, + # "conditioning": conditioning, + 'idx': idx, + 'cc': row['cc'], + 'ncc': ncc, + 'noise_level': noise_level, + 'iso': row['iso'], + # 'x_warp': row['x_warp'], + # 'y_warp': row['y_warp'], + # 'gt_mean': row['gt_mean'], + # 'noisy_mean': row['noisy_mean'], + 'gains': gains, + # 
'noisy_rggb_tensor': noisy_rggb_tensor + } + + using_NAF = True + if using_NAF: + noisy_rggb = noisy.as_rggb( dims=crop_dim) + noisy_rggb_tensor = torch.from_numpy(noisy_rggb).float() + conditioning = torch.tensor([float(row['iso'])/6400, 0, 0, 0]).float() + return_dict['conditioning'] = conditioning + return_dict['noisy_rggb_tensor'] = noisy_rggb_tensor + return return_dict diff --git a/src/training/ShadowAwareLoss.py b/src/training/ShadowAwareLoss.py new file mode 100644 index 0000000..447fa0a --- /dev/null +++ b/src/training/ShadowAwareLoss.py @@ -0,0 +1,88 @@ +import torch +import torch.nn as nn +from pytorch_msssim import ms_ssim + + +class ShadowAwareLoss(nn.Module): + def __init__(self, + alpha=0.2, + beta=5.0, + l1_weight=0.16, + ssim_weight=0.84, + tv_weight=0.0, + vgg_loss_weight=0.0, + apply_gamma_fn=None, + vgg_feature_extractor=None, + device=None): + """ + Shadow-aware image restoration loss. + + Args: + alpha: Minimum tone weight for bright pixels. + beta: Controls how quickly weight increases in shadows. + l1_weight, ssim_weight, tv_weight, vgg_loss_weight: Loss scaling factors. + apply_gamma_fn: Optional function to apply gamma correction to input tensors. + vgg_feature_extractor: Optional VGG feature extractor returning feature maps. + device: Optional device to move inputs and buffers to. 
+ """ + super().__init__() + self.alpha = alpha + self.beta = beta + self.l1_weight = l1_weight + self.ssim_weight = ssim_weight + self.tv_weight = tv_weight + self.vgg_loss_weight = vgg_loss_weight + self.apply_gamma_fn = apply_gamma_fn + self.vfe = vgg_feature_extractor + self.device = device + + if device is not None: + self.to(device) + + def forward(self, pred, target): + """ + Args: + pred: [B, C, H, W] restored image in [0,1] + target: [B, C, H, W] ground truth in [0,1] + """ + if self.apply_gamma_fn is not None: + pred = torch.clamp(pred, 1e-6, 1.0) + target = torch.clamp(target, 1e-6, 1.0) + pred = self.apply_gamma_fn(pred).clamp(1e-6, 1) + target = self.apply_gamma_fn(target).clamp(1e-6, 1) + + # Convert to luminance (BT.709) + lum = 0.2126 * target[:, 0] + 0.7152 * target[:, 1] + 0.0722 * target[:, 2] # [B, H, W] + tone_weight = self.alpha + (1.0 - self.alpha) * torch.exp(-self.beta * lum) + tone_weight = tone_weight.unsqueeze(1) # [B, 1, H, W] + + # Weighted L1 loss + l1 = (tone_weight * torch.abs(pred - target)).mean() + + # Weighted MS-SSIM loss + ssim = 1 - ms_ssim(pred, target, data_range=1.0, size_average=True) + + # TV loss only in low-light areas + shadow_mask = (lum < 0.2).float().unsqueeze(1) + dx = torch.abs(pred[:, :, :, 1:] - pred[:, :, :, :-1]) + dy = torch.abs(pred[:, :, 1:, :] - pred[:, :, :-1, :]) + tv = ((shadow_mask[:, :, :, 1:] * dx).mean() + + (shadow_mask[:, :, 1:, :] * dy).mean()) + + # Optional VGG perceptual loss + vgg_loss_val = 0 + if self.vgg_loss_weight != 0 and self.vfe is not None: + with torch.no_grad(): + pred_features = self.vfe(pred) + target_features = self.vfe(target) + vgg_loss_val = nn.functional.mse_loss(pred_features[0], target_features[0]) + + # Combine weighted terms + total_loss = ( + self.l1_weight * l1 + + self.ssim_weight * ssim + + self.tv_weight * tv + + self.vgg_loss_weight * vgg_loss_val + ) + + return total_loss diff --git a/src/training/SmallRawDataset.py b/src/training/SmallRawDataset.py index 
0bf1eb2..ff3d013 100644 --- a/src/training/SmallRawDataset.py +++ b/src/training/SmallRawDataset.py @@ -73,7 +73,7 @@ def __getitem__(self, idx): gt_image = gt_image * demosaiced_noisy.mean() / gt_image.mean() sparse, _ = cfa_to_sparse(bayer_data) - rggb = bayer_data.reshape(h // 2, 2, w // 2, 2, 1).transpose(3, 1, 4, 0, 2).reshape(4, h // 2, w // 2) + rggb =bayer_data.reshape(h // 2, 2, w // 2, 2, 1).transpose(1, 3, 4, 0, 2).reshape(4, h // 2, w // 2) # Convert to tensors output = { diff --git a/src/training/VGGFeatureExtractor.py b/src/training/VGGFeatureExtractor.py new file mode 100644 index 0000000..d9a01b4 --- /dev/null +++ b/src/training/VGGFeatureExtractor.py @@ -0,0 +1,81 @@ +import torch.nn as nn +import torch +import torch.nn.functional as F +import pytorch_msssim +import math + +def custom_weight_init(m): + if isinstance(m, nn.Conv2d): + k = m.kernel_size[0] # assuming square kernels + c_in = m.in_channels + n_l = (k ** 2) * c_in + std = math.sqrt(2.0 / n_l) + nn.init.normal_(m.weight, mean=0.0, std=std) + if m.bias is not None: + nn.init.constant_(m.bias, 0.0) + elif isinstance(m, nn.Linear): + n_l = m.in_features + std = math.sqrt(2.0 / n_l) + nn.init.normal_(m.weight, mean=0.0, std=std) + if m.bias is not None: + nn.init.constant_(m.bias, 0.0) + + +class VGGFeatureExtractor(nn.Module): + """ + VGG-like network for perceptual loss with runtime hot-swappable activations. + Returns features from selected layers. 
+ """ + def __init__(self, config=None, feature_layers=None, activation=None): + super().__init__() + if config is None: + config = [(2, 64), (2, 128), (3, 256), (3, 512), (3, 512)] + + if feature_layers is None: + feature_layers = [3, 8, 15, 22, 29] + + if activation is None: + activation = lambda: nn.ReLU(inplace=False) + + layers = [] + in_channels = 3 + for num_convs, out_channels in config: + for _ in range(num_convs): + layers.append(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)) + layers.append(activation()) + in_channels = out_channels + layers.append(nn.MaxPool2d(kernel_size=2, stride=2)) + + self.features = nn.Sequential(*layers) + self.feature_layers = feature_layers + self.activation_factory = activation # store for later swapping + self.apply(custom_weight_init) + + def set_activation(self, activation_cls, **kwargs): + """ + Replace all activation layers with new activation. + activation_cls: activation class (e.g., nn.GELU, nn.LeakyReLU) + kwargs: any keyword args for the activation class + """ + for i, layer in enumerate(self.features): + if isinstance(layer, nn.ReLU) or \ + isinstance(layer, nn.LeakyReLU) or \ + isinstance(layer, nn.GELU) or \ + isinstance(layer, nn.SiLU) or \ + isinstance(layer, nn.Identity) or \ + isinstance(layer, nn.Tanh): + self.features[i] = activation_cls(**kwargs) + # Update the stored factory for future reference + self.activation_factory = lambda: activation_cls(**kwargs) + + def forward(self, x): + outputs = [] + for i, layer in enumerate(self.features): + x = layer(x) + if i in self.feature_layers: + outputs.append(x) + return outputs + + + + diff --git a/src/training/align_images.py b/src/training/align_images.py new file mode 100644 index 0000000..ddc08ec --- /dev/null +++ b/src/training/align_images.py @@ -0,0 +1,181 @@ +import cv2 +import numpy as np +from skimage.metrics import structural_similarity as ssim, peak_signal_noise_ratio as psnr +import pandas as pd +import os +from torch.utils.data 
def align_clean_to_noisy(clean_img, noisy_img, refine=True, verbose=False):
    """
    Aligns the clean image to the noisy image.

    A coarse translation is estimated with phase correlation; when ``refine``
    is true an affine ECC refinement is run on top of it.

    Args:
        clean_img: image to warp (2-D, or 3-channel; 3-channel input is
            converted with ``COLOR_BGR2GRAY``).
        noisy_img: reference image that defines the target geometry.
        refine: run the ECC affine refinement after the translation step.
        verbose: print the ECC correlation, or the OpenCV error if ECC fails.

    Returns:
        - aligned image
        - best warp matrix (2x3 affine, to be used with ``cv2.WARP_INVERSE_MAP``)
        - metrics dict (PSNR/SSIM before and after alignment, the
          phase-correlation shift/response and the warp entries M00..M12)
    """

    def to_gray(img):
        return cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) if img.ndim == 3 else img

    def to_gray_normalised(img):
        # zero-mean / unit-variance float32 for phase correlation and ECC
        g = to_gray(img).astype(np.float32)
        return (g - g.mean()) / (g.std() + 1e-8)

    clean_gray = to_gray_normalised(clean_img)
    noisy_gray = to_gray_normalised(noisy_img)
    h, w = clean_gray.shape

    def warp_clean(matrix):
        return cv2.warpAffine(clean_img, matrix, (w, h),
                              flags=cv2.INTER_CUBIC + cv2.WARP_INVERSE_MAP,
                              borderMode=cv2.BORDER_REFLECT)

    # --- PHASE CORRELATION (coarse translation) ---
    shift, response = cv2.phaseCorrelate(noisy_gray, clean_gray)  # (dx, dy)
    dx, dy = shift
    M_trans = np.array([[1, 0, dx], [0, 1, dy]], dtype=np.float32)
    aligned = warp_clean(M_trans)
    M_total = M_trans

    # --- optional ECC refinement (affine) ---
    if refine:
        warp_matrix = np.eye(2, 3, dtype=np.float32)
        criteria = (cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, 2000, 1e-8)
        try:
            cc, warp_matrix = cv2.findTransformECC(
                noisy_gray,                            # template (target)
                to_gray(aligned).astype(np.float32),   # translated clean image
                warp_matrix,
                cv2.MOTION_AFFINE,
                criteria,
                None,
                5,
            )
        except cv2.error as e:
            # Keep the translation-only result when ECC does not converge.
            if verbose:
                print("ECC failed:", e)
        else:
            if verbose:
                print(f"ECC converged: corr={cc:.5f}")
            # Both matrices are inverse maps (destination -> source):
            #   aligned(y) = clean(M_trans @ y)   and   noisy(x) ~ aligned(M_ecc @ x)
            # hence noisy(x) ~ clean(M_trans @ M_ecc @ x); the translation
            # must be the left factor.
            M1 = np.vstack([M_trans, [0, 0, 1]])
            M2 = np.vstack([warp_matrix, [0, 0, 1]])
            M_total = (M1 @ M2)[:2, :]
            aligned = warp_clean(M_total)

    # --- compute metrics ---
    clean_g = to_gray(clean_img)
    noisy_g = to_gray(noisy_img)
    aligned_g = to_gray(aligned)

    metrics = {
        "PSNR_before": psnr(noisy_g, clean_g, data_range=255),
        "PSNR_after": psnr(noisy_g, aligned_g, data_range=255),
        "SSIM_before": ssim(noisy_g, clean_g, data_range=255),
        "SSIM_after": ssim(noisy_g, aligned_g, data_range=255),
        "dx": dx,
        "dy": dy,
        "response": response,
    }

    # flatten warp matrix for CSV
    for i in range(2):
        for j in range(3):
            metrics[f"M{i}{j}"] = float(M_total[i, j])

    return aligned, M_total, metrics
class AlignImages(Dataset):
    """Dataset yielding ``(gt_image, demosaiced_noisy, aligned, metrics)`` per CSV row.

    For each row the Bayer JPEG and the ground-truth JPEG are read, the Bayer
    data is demosaiced (Malvar 2004), and the ground truth is aligned onto it
    with ``align_clean_to_noisy(..., refine=False)``.
    """

    def __init__(self, path, csv, crop_size=180, buffer=10, validation=False):
        super().__init__()
        self.df = pd.read_csv(os.path.join(path, csv))
        self.path = path
        self.crop_size = crop_size
        self.buffer = buffer
        self.coordinate_iso = 6400
        self.validation = validation

    def __len__(self):
        return len(self.df)

    @staticmethod
    def _read_image(image_path):
        """Read one image file through imageio."""
        with imageio.imopen(image_path, "r") as image_resource:
            return image_resource.read()

    def __getitem__(self, idx):
        row = self.df.iloc[idx]

        # The stored 2x3 matrix columns (m00..m12) are popped from the row;
        # the matrix itself is not used further here because the alignment
        # is re-estimated below.
        stored = [row.pop(f"m{i}{j}") for i in range(2) for j in range(3)]
        warp_matrix = np.array(stored, dtype=np.float32).reshape((2, 3))

        # Load images
        bayer_path = f"{self.path}/{row.noisy_image}_bayer.jpg"
        gt_path = f"{self.path}/{row.gt_image}.jpg"
        bayer_data = self._read_image(bayer_path)
        gt_image = self._read_image(gt_path)

        demosaiced_noisy = demosaicing_CFA_Bayer_Malvar2004(bayer_data).astype(np.uint8)
        aligned, matrix, metrics = align_clean_to_noisy(gt_image, demosaiced_noisy, refine=False)

        residual = demosaiced_noisy.astype(int) - aligned.astype(int)
        metrics['iso'] = row.iso
        metrics['std'] = residual.std()
        metrics['bayer_path'] = bayer_path
        metrics['gt_path'] = gt_path
        return gt_image, demosaiced_noisy, aligned, metrics
def simulate_sparse(image, pattern="RGGB", cfa_type="bayer"):
    """
    Simulate a sparse CFA (Color Filter Array) from an RGB image.

    Args:
        image: numpy array (3, H, W), RGB image in [0, 1] or [0, 255].
        pattern: CFA pattern string, one of {"RGGB","BGGR","GRBG","GBRG"} for Bayer,
                 or ignored if cfa_type="xtrans".
        cfa_type: "bayer" or "xtrans".

    Returns:
        cfa: numpy array (3, H, W); each pixel keeps only the channel its
            filter samples, the other two channels are zero.
        sparse_mask: numpy array (3, H, W), 1 where ``cfa`` holds a sample.

    Raises:
        ValueError: unknown Bayer ``pattern`` or unknown ``cfa_type``.
    """
    _, H, W = image.shape

    if cfa_type == "bayer":
        # 2x2 Bayer tiles, one string per row
        bayer_tiles = {
            "RGGB": ("RG", "GB"),
            "BGGR": ("BG", "GR"),
            "GRBG": ("GR", "BG"),
            "GBRG": ("GB", "RG"),
        }
        if pattern not in bayer_tiles:
            raise ValueError(f"Unknown Bayer pattern: {pattern}")
        tile = bayer_tiles[pattern]
    elif cfa_type == "xtrans":
        # Fuji X-Trans 6x6 repeating pattern
        tile = ("GBRGRB", "RGGBGG", "BGGRGG", "GRBGBR", "BGGRGG", "RGGBGG")
    else:
        raise ValueError(f"Unknown CFA type: {cfa_type}")

    cfa = np.zeros((3, H, W), dtype=image.dtype)
    sparse_mask = np.zeros((3, H, W), dtype=image.dtype)
    channel_of = {"R": 0, "G": 1, "B": 2}

    # The tile repeats with its own period in both directions; the mask must
    # use the same stride as the samples (the X-Trans mask previously used
    # the Bayer stride of 2 and therefore did not match ``cfa``).
    period = len(tile)
    for i, tile_row in enumerate(tile):
        for j, colour in enumerate(tile_row):
            ch = channel_of[colour]
            cfa[ch, i::period, j::period] = image[ch, i::period, j::period]
            sparse_mask[ch, i::period, j::period] = 1

    return cfa, sparse_mask
def bilinear_demosaic(sparse, pattern="RGGB", cfa_type="bayer"):
    """
    Simple bilinear demosaicing for Bayer and X-Trans CFAs.

    Args:
        sparse: numpy array (3, H, W) with a sparse representation of the CFA image.
        pattern: CFA pattern string, one of {"RGGB","BGGR","GRBG","GBRG"} for Bayer.
            Currently unused: the kernels do not depend on the phase of the pattern.
        cfa_type: "bayer" or "xtrans".

    Returns:
        rgb: numpy array (3, H, W), same dtype as ``sparse``, interpolated image.

    Raises:
        ValueError: unknown ``cfa_type``.
    """
    C, H, W = sparse.shape
    rgb = np.zeros((C, H, W), dtype=sparse.dtype)

    if cfa_type == "bayer":
        # Bilinear interpolation kernels: red/blue share one, green has its own.
        red_blue = np.array([[.25, .5, .25], [.5, 1, .5], [.25, .5, .25]], dtype=np.float32)
        green = np.array([[0, .25, 0], [.25, 1, .25], [0, .25, 0]], dtype=np.float32)
        kernels = [red_blue, green, red_blue]
    elif cfa_type == "xtrans":
        # One unit-sum 5x5 kernel for all three channels. The list was
        # previously never built in this branch, so the loop below failed
        # on the unbound name ``kernels``.
        kernel = np.array([
            [1, 2, 4, 2, 1],
            [2, 4, 8, 4, 2],
            [4, 8, 16, 8, 4],
            [2, 4, 8, 4, 2],
            [1, 2, 4, 2, 1]], dtype=np.float32)
        kernel /= kernel.sum()
        kernels = [kernel, kernel, kernel]
    else:
        raise ValueError(f"Unknown CFA type: {cfa_type}")

    # Interpolate each channel
    for ch in range(3):
        rgb[ch, ...] = convolve(sparse[ch], kernels[ch], mode="mirror")

    return rgb
def cfa_to_sparse(image, pattern="RGGB", cfa_type="bayer"):
    """
    Make a sparse representation from a CFA

    Args:
        image: numpy array (H, W), single-plane mosaic in [0, 1] or [0, 255].
        pattern: CFA pattern string, one of {"RGGB","BGGR","GRBG","GBRG"} for Bayer,
                 or ignored if cfa_type="xtrans".
        cfa_type: "bayer" or "xtrans".

    Returns:
        cfa: numpy array (3, H, W); every mosaic sample is placed in the
            channel of its colour filter, the other channels are zero.
        sparse_mask: numpy array (3, H, W), 1 where ``cfa`` holds a sample.

    Raises:
        ValueError: unknown Bayer ``pattern`` or unknown ``cfa_type``.
    """
    H, W = image.shape

    if cfa_type == "bayer":
        # 2x2 Bayer tiles, one string per row
        bayer_tiles = {
            "RGGB": ("RG", "GB"),
            "BGGR": ("BG", "GR"),
            "GRBG": ("GR", "BG"),
            "GBRG": ("GB", "RG"),
        }
        if pattern not in bayer_tiles:
            raise ValueError(f"Unknown Bayer pattern: {pattern}")
        tile = bayer_tiles[pattern]
    elif cfa_type == "xtrans":
        # Fuji X-Trans 6x6 repeating pattern
        tile = ("GBRGRB", "RGGBGG", "BGGRGG", "GRBGBR", "BGGRGG", "RGGBGG")
    else:
        raise ValueError(f"Unknown CFA type: {cfa_type}")

    cfa = np.zeros((3, H, W), dtype=image.dtype)
    sparse_mask = np.zeros((3, H, W), dtype=image.dtype)
    channel_of = {"R": 0, "G": 1, "B": 2}

    # Samples and mask share the tile period (the X-Trans mask previously
    # used stride 2 while the samples used stride 6).
    period = len(tile)
    for i, tile_row in enumerate(tile):
        for j, colour in enumerate(tile_row):
            ch = channel_of[colour]
            cfa[ch, i::period, j::period] = image[i::period, j::period]
            sparse_mask[ch, i::period, j::period] = 1

    return cfa, sparse_mask
def visualize(idxs, _model, dataset, _device, _loss_func):
    """Run the model on selected dataset items and plot prediction / input / target.

    For every index a 2x2 figure is shown: gamma-corrected prediction, noisy
    input, ground truth, and the prediction error offset by 0.5.

    Args:
        idxs: iterable of dataset indices to visualise.
        _model: network called as ``_model(rggb, conditioning, noisy)``.
        dataset: indexable returning a dict with the keys ``'noisy'``,
            ``'conditioning'``, ``'aligned'`` and ``'rggb'``.
        _device: device the tensors are moved to; also selects the autocast backend.
        _loss_func: callable ``(pred, gt) -> scalar tensor``.

    Returns:
        ``(mean loss per image, elapsed seconds)``.
    """
    import matplotlib.pyplot as plt

    def to_display(batch):
        # float(): bfloat16 tensors (possible under autocast) cannot be handed
        # to matplotlib/numpy directly.
        return apply_gamma_torch(batch[0].detach().float().cpu().permute(1, 2, 0))

    # Inference must not run with dropout active; remember the previous mode
    # and restore it afterwards instead of forcing train mode.
    was_training = _model.training
    _model.eval()
    device_type = torch.device(_device).type  # was hard-coded to "mps"

    total_loss, n_images, total_final_image_loss = 0.0, 0, 0.0
    start = perf_counter()

    try:
        for idx in idxs:
            row = dataset[idx]
            noisy = row['noisy'].unsqueeze(0).float().to(_device)
            conditioning = row['conditioning'].float().unsqueeze(0).to(_device)
            gt = row['aligned'].unsqueeze(0).float().to(_device)
            rggb = row['rggb'].unsqueeze(0).float().to(_device)

            conditioning = make_conditioning(conditioning, _device)
            with torch.no_grad():
                with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
                    pred = _model(rggb, conditioning, noisy)
                    loss = _loss_func(pred, gt)

            total_loss += float(loss.detach().cpu())
            n_images += gt.shape[0]

            # Testing final image quality
            final_image_loss = nn.functional.l1_loss(pred, gt)
            total_final_image_loss += final_image_loss.item()

            fig = plt.figure(figsize=(15, 15))
            pred_img = to_display(pred)
            noisy_img = to_display(noisy)
            gt_img = to_display(gt)
            panels = (pred_img, noisy_img, gt_img, pred_img - gt_img + 0.5)
            for position, panel in enumerate(panels, start=1):
                plt.subplot(2, 2, position)
                plt.imshow(panel)
            plt.show()
            plt.close(fig)  # release the figure instead of accumulating one per sample
    finally:
        _model.train(was_training)

    denom = max(1, n_images)  # an empty ``idxs`` must not divide by zero
    print(
        f"Train loss: {total_loss/denom:.6f}  "
        f"Final image val loss: {total_final_image_loss/denom:.6f}  "
        f"Time: {perf_counter()-start:.1f}s  "
        f"Images: {n_images}")

    return total_loss / denom, perf_counter()-start
def visualize(idxs, _model, dataset, _device, _loss_func, rggb=False):
    """Run the model on selected dataset items and plot the results.

    For every index a 2x3 figure is shown: gamma-corrected prediction, noisy
    input, ground truth, then the three pairwise differences offset by 0.5
    (prediction - truth, noisy - prediction, noisy - truth).

    Args:
        idxs: iterable of dataset indices to visualise.
        _model: network called as ``_model(input, conditioning, noisy)``.
        dataset: indexable returning a dict with the keys ``'noisy'``,
            ``'conditioning'``, ``'aligned'`` and the selected input key.
        _device: device the tensors are moved to; also selects the autocast backend.
        _loss_func: callable ``(pred, gt) -> scalar tensor``.
        rggb: feed ``row['rggb']`` to the model instead of ``row['sparse']``.

    Returns:
        ``(mean loss per image, elapsed seconds)``.
    """
    import matplotlib.pyplot as plt

    def to_display(batch):
        # float(): bfloat16 tensors (possible under autocast) cannot be handed
        # to matplotlib/numpy directly.
        return apply_gamma_torch(batch[0].detach().float().cpu().permute(1, 2, 0))

    # Only the representation that is actually fed to the model is read.
    input_key = 'rggb' if rggb else 'sparse'

    # Inference must not run with dropout active; remember the previous mode
    # and restore it afterwards instead of forcing train mode.
    was_training = _model.training
    _model.eval()
    device_type = torch.device(_device).type  # was hard-coded to "mps"

    total_loss, n_images, total_final_image_loss = 0.0, 0, 0.0
    start = perf_counter()

    try:
        for idx in idxs:
            row = dataset[idx]
            noisy = row['noisy'].unsqueeze(0).float().to(_device)
            conditioning = row['conditioning'].float().unsqueeze(0).to(_device)
            gt = row['aligned'].unsqueeze(0).float().to(_device)
            model_input = row[input_key].unsqueeze(0).float().to(_device)

            conditioning = make_conditioning(conditioning, _device)

            with torch.no_grad():
                with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
                    pred = _model(model_input, conditioning, noisy)
                    loss = _loss_func(pred, gt)

            total_loss += float(loss.detach().cpu())
            n_images += gt.shape[0]
            # Testing final image quality
            final_image_loss = nn.functional.l1_loss(pred, gt)
            total_final_image_loss += final_image_loss.item()

            fig = plt.figure(figsize=(30, 15))
            pred_img = to_display(pred)
            noisy_img = to_display(noisy)
            gt_img = to_display(gt)
            panels = (
                pred_img,
                noisy_img,
                gt_img,
                pred_img - gt_img + 0.5,
                noisy_img - pred_img + 0.5,
                noisy_img - gt_img + 0.5,
            )
            for position, panel in enumerate(panels, start=1):
                plt.subplot(2, 3, position)
                plt.imshow(panel)
            plt.show()
            plt.close(fig)  # release the figure instead of accumulating one per sample
    finally:
        _model.train(was_training)

    denom = max(1, n_images)  # an empty ``idxs`` must not divide by zero
    print(
        f"Train loss: {total_loss/denom:.6f}  "
        f"Final image val loss: {total_final_image_loss/denom:.6f}  "
        f"Time: {perf_counter()-start:.1f}s  "
        f"Images: {n_images}")

    return total_loss / denom, perf_counter()-start
def color_jitter_0_1(img, brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1):
    """
    Applies color jitter to a NumPy image array with pixel values scaled 0-1.

    The random factors are drawn in the order brightness, contrast,
    saturation, hue; a factor is drawn only when its argument is > 0.

    Args:
        img (np.ndarray): Input image as a NumPy array (H, W, 3) with values in [0, 1].
        brightness (float): Max variation for brightness factor.
        contrast (float): Max variation for contrast factor.
        saturation (float): Max variation for saturation factor.
        hue (float): Max variation for hue factor.

    Returns:
        np.ndarray: The jittered float32 image with values clipped to [0, 1].
    """
    # Ensure the image is a float type for calculations
    img = img.astype(np.float32)

    # 1. Adjust brightness
    if brightness > 0:
        b_factor = np.random.uniform(max(0, 1 - brightness), 1 + brightness)
        img = img * b_factor

    # 2. Adjust contrast (blend towards the per-channel mean)
    if contrast > 0:
        c_factor = np.random.uniform(max(0, 1 - contrast), 1 + contrast)
        img = img * c_factor + (1 - c_factor) * np.mean(img, axis=(0, 1))

    # Convert to HSV to adjust saturation and hue.
    # PIL expects uint8, so we scale from [0, 1] to [0, 255]
    img_pil = Image.fromarray(np.clip(img * 255, 0, 255).astype(np.uint8))
    img_hsv = np.array(img_pil.convert('HSV'), dtype=np.float32)

    # 3. Adjust saturation
    if saturation > 0:
        s_factor = np.random.uniform(max(0, 1 - saturation), 1 + saturation)
        img_hsv[:, :, 1] *= s_factor

    # 4. Adjust hue
    if hue > 0:
        h_factor = np.random.uniform(-hue, hue)
        # Hue in Pillow is 0-255 and circular, so the shifted value wraps
        # around. Clipping it would pile every out-of-range hue up at the
        # ends of the range instead of rotating it.
        img_hsv[:, :, 0] = np.mod(img_hsv[:, :, 0] + h_factor * 255.0, 256.0)

    # Saturation and value are not circular: clip them to [0, 255].
    img_hsv[:, :, 1:] = np.clip(img_hsv[:, :, 1:], 0, 255)

    # Convert back to RGB and scale back to [0, 1]
    img_jittered_pil = Image.fromarray(img_hsv.astype(np.uint8), 'HSV').convert('RGB')
    jittered_img = np.array(img_jittered_pil, dtype=np.float32) / 255.0

    # Final clipping to ensure output is in [0, 1]
    return np.clip(jittered_img, 0, 1)
def bilinear_demosaic_torch(sparse, pattern="RGGB", cfa_type="bayer"):
    """
    Simple bilinear demosaicing for Bayer and X-Trans CFAs (Torch version).

    Args:
        sparse: torch tensor (3, H, W) with a sparse representation of the CFA image.
        pattern: CFA pattern string, one of {"RGGB","BGGR","GRBG","GBRG"} for Bayer.
            Currently unused: the kernels do not depend on the phase of the pattern.
        cfa_type: "bayer" or "xtrans".

    Returns:
        rgb: torch tensor (3, H, W), same dtype/device as ``sparse``.

    Raises:
        ValueError: unknown ``cfa_type``.
    """
    # Imported here because the module never imports torch.nn.functional
    # (calling this function used to fail on the undefined name ``F``).
    import torch
    import torch.nn.functional as F

    C, H, W = sparse.shape
    device = sparse.device
    dtype = sparse.dtype

    if cfa_type == "bayer":
        red_blue = torch.tensor([[.25, .5, .25],
                                 [.5, 1, .5],
                                 [.25, .5, .25]], dtype=dtype, device=device)
        green = torch.tensor([[0, .25, 0],
                              [.25, 1, .25],
                              [0, .25, 0]], dtype=dtype, device=device)
        kernels = [red_blue, green, red_blue]
    elif cfa_type == "xtrans":
        kernel = torch.tensor([
            [1, 2, 4, 2, 1],
            [2, 4, 8, 4, 2],
            [4, 8, 16, 8, 4],
            [2, 4, 8, 4, 2],
            [1, 2, 4, 2, 1]], dtype=dtype, device=device)
        kernel /= kernel.sum()
        kernels = [kernel, kernel, kernel]
    else:
        raise ValueError(f"Unknown CFA type: {cfa_type}")

    rgb = torch.zeros_like(sparse)

    # Apply each kernel using conv2d on a reflect-padded single-channel image
    for ch in range(3):
        k = kernels[ch].unsqueeze(0).unsqueeze(0)   # (1, 1, kH, kW)
        x = sparse[ch:ch + 1].unsqueeze(0)          # (1, 1, H, W)
        pad_h, pad_w = k.shape[-2] // 2, k.shape[-1] // 2
        padded = F.pad(x, (pad_w, pad_w, pad_h, pad_h), mode="reflect")
        rgb[ch] = F.conv2d(padded, k)[0, 0]

    return rgb
+ """ + H, W= image.shape + cfa = np.zeros((3, H, W), dtype=image.dtype) + sparse_mask = np.zeros((3, H, W), dtype=image.dtype) + if cfa_type == "bayer": + # 2×2 Bayer masks + masks = { + "RGGB": np.array([["R", "G"], ["G", "B"]]), + "BGGR": np.array([["B", "G"], ["G", "R"]]), + "GRBG": np.array([["G", "R"], ["B", "G"]]), + "GBRG": np.array([["G", "B"], ["R", "G"]]), + } + if pattern not in masks: + raise ValueError(f"Unknown Bayer pattern: {pattern}") + + mask = masks[pattern] + cmap = {"R": 0, "G": 1, "B": 2} + + for i in range(2): + for j in range(2): + ch = cmap[mask[i, j]] + cfa[ch, i::2, j::2] = image[i::2, j::2] + sparse_mask[ch, i::2, j::2] = 1 + elif cfa_type == "xtrans": + # Fuji X-Trans 6×6 repeating pattern + xtrans_pattern = np.array([ + ["G","B","R","G","R","B"], + ["R","G","G","B","G","G"], + ["B","G","G","R","G","G"], + ["G","R","B","G","B","R"], + ["B","G","G","R","G","G"], + ["R","G","G","B","G","G"], + ]) + cmap = {"R":0, "G":1, "B":2} + + for i in range(6): + for j in range(6): + ch = cmap[xtrans_pattern[i, j]] + cfa[ch, i::6, j::6] = image[i::6, j::6] + sparse_mask[ch, i::2, j::2] = 1 + else: + raise ValueError(f"Unknown CFA type: {cfa_type}") + + return cfa, sparse_mask + +import torch + +def apply_gamma_torch(img: torch.tensor, gamma: float = 2.2) -> torch.tensor: + img = img.clamp(0, 1) + return (img ** (1.0 / gamma)).clamp(0, 1) + +def apply_gamma(img: np.ndarray, gamma: float = 2.2) -> np.ndarray: + img = np.clip(img, 0, 1) + return np.power(img, 1.0 / gamma) + +def inverse_gamma_tone_curve(img: np.ndarray, gamma: float = 2.2) -> np.ndarray: + img = np.clip(img, 0, 1) + return np.power(img, gamma) \ No newline at end of file From 631c41d0c253b50d47cdda36af1d4f3dd997ff10 Mon Sep 17 00:00:00 2001 From: Ryan Mueller Date: Thu, 16 Oct 2025 16:33:03 -0400 Subject: [PATCH 04/56] This branch contains training code for mps. 
Code to produce the small dataset is included --- 0_align_images.ipynb | 133 +++++++ 0_produce_small_dataset.ipynb | 669 ++++++++++++++++++++++++++++++++++ config.yaml | 6 + download_rawnind.sh | 43 +++ src/training/load_config.py | 14 + 5 files changed, 865 insertions(+) create mode 100644 0_align_images.ipynb create mode 100644 0_produce_small_dataset.ipynb create mode 100644 config.yaml create mode 100644 download_rawnind.sh create mode 100644 src/training/load_config.py diff --git a/0_align_images.ipynb b/0_align_images.ipynb new file mode 100644 index 0000000..125aeff --- /dev/null +++ b/0_align_images.ipynb @@ -0,0 +1,133 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "ac7bbf09", + "metadata": {}, + "outputs": [], + "source": [ + "from tqdm import tqdm\n", + "import matplotlib.pyplot as plt\n", + "import pandas as pd\n", + "import numpy as np\n", + "from pathlib import Path" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "da5ede59", + "metadata": {}, + "outputs": [], + "source": [ + "from src.training.align_images import AlignImages\n", + "\n", + "from src.training.load_config import load_config" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9d50e4c7", + "metadata": {}, + "outputs": [], + "source": [ + "run_config = load_config()\n", + "path = Path(run_config['jpeg_output_subdir'])\n", + "outpath = Path(run_config['base_data_dir']) / run_config['jpeg_output_subdir']\n", + "secondary_align_csv = outpath / run_config['secondary_align_csv']" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "19e7da21", + "metadata": {}, + "outputs": [], + "source": [ + "dataset = AlignImages(outpath, run_config['align_csv'])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c549b157", + "metadata": {}, + "outputs": [], + "source": [ + "metric_list = []\n", + "for i in tqdm(range(len(dataset))):\n", + " gt, noisy, aligned, metrics = dataset[i]\n", + " 
metric_list.append(metrics)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5acbb11e", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "df = pd.DataFrame(metric_list)\n", + "df.to_csv(secondary_align_csv)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6ec8ee48", + "metadata": {}, + "outputs": [], + "source": [ + "df = pd.read_csv(secondary_align_csv)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "192ded31", + "metadata": {}, + "outputs": [], + "source": [ + "(df.PSNR_after - df.PSNR_before).hist(bins = np.linspace(-1, 4, 100))\n", + "plt.ylabel('Images')\n", + "plt.xlabel('Delta PSNR')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c04b9532", + "metadata": {}, + "outputs": [], + "source": [ + "(df.SSIM_after - df.SSIM_before).hist(bins = np.linspace(-.01, .1, 100))\n", + "plt.ylabel('Images')\n", + "plt.xlabel('Delta SSIM')" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "OnSight", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.11" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/0_produce_small_dataset.ipynb b/0_produce_small_dataset.ipynb new file mode 100644 index 0000000..1c419a3 --- /dev/null +++ b/0_produce_small_dataset.ipynb @@ -0,0 +1,669 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 2, + "id": "23c4096e", + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "import numpy as np\n", + "import imageio\n", + "import cv2\n", + "import matplotlib.pyplot as plt\n", + "import os\n", + "from pathlib import Path\n", + "import re\n", + "from PIL import Image\n", + "import re\n", + "from collections import defaultdict\n", + "\n", + 
def pair_images_by_scene(file_list, min_iso=100):
    """Pair every usable RAW file with the lowest-ISO file of the same scene.

    A file is usable when its basename contains ``_ISO<digits>_``, the parsed
    ISO is at least ``min_iso``, and the basename does not contain
    ``X-Trans``. The scene name is the part of the basename before ``_GT_``
    when that marker is present, otherwise the part before ``_ISO``.

    Args:
        file_list (list of str): Paths to RAW files.
        min_iso (int): Files with a lower ISO are ignored (default 100).

    Returns:
        dict: ``{scene_name: [(img_path, gt_path), ...]}`` where ``gt_path``
        is the lowest-ISO file kept for the scene. Pairs are ordered by
        ascending ISO; a scene with a single usable file maps to ``[]``.
    """
    iso_regex = re.compile(r"_ISO(\d+)_")

    # Bucket (iso, path) entries per scene, keeping first-seen scene order.
    by_scene = defaultdict(list)
    for candidate in file_list:
        name = os.path.basename(candidate)
        found = iso_regex.search(name)
        if found is None:
            continue
        iso_value = int(found.group(1))
        if iso_value < min_iso or 'X-Trans' in name:
            continue
        marker = "_GT_" if "_GT_" in name else "_ISO"
        by_scene[name.split(marker)[0]].append((iso_value, candidate))

    result = {}
    for scene_name, entries in by_scene.items():
        # sorted() is stable, so files with equal ISO keep their input order.
        ordered = sorted(entries, key=lambda entry: entry[0])
        reference = ordered[0][1]
        result[scene_name] = [
            (candidate, reference)
            for _, candidate in ordered
            if candidate != reference
        ]
    return result
(to be warped).\n", + " num_features (int): The number of features for ORB to detect.\n", + "\n", + " Returns:\n", + " np.array: The 2x3 Euclidean warp matrix, or the identity matrix if it fails.\n", + " \"\"\"\n", + " try:\n", + " # Initialize ORB detector\n", + " orb = cv2.ORB_create(nfeatures=num_features)\n", + "\n", + " # Find the keypoints and descriptors with ORB\n", + " keypoints1, descriptors1 = orb.detectAndCompute(img1_gray, None)\n", + " keypoints2, descriptors2 = orb.detectAndCompute(img2_gray, None)\n", + " \n", + " # Descriptors can be None if no keypoints are found\n", + " if descriptors1 is None or descriptors2 is None:\n", + " return np.eye(2, 3, dtype=np.float32)\n", + "\n", + " # Create BFMatcher object\n", + " # NORM_HAMMING is used for binary descriptors like ORB\n", + " bf = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=True)\n", + "\n", + " # Match descriptors\n", + " matches = bf.match(descriptors1, descriptors2)\n", + "\n", + " # Sort them in the order of their distance (best matches first)\n", + " matches = sorted(matches, key=lambda x: x.distance)\n", + "\n", + " # Keep only the top matches (e.g., top 50 or 15% of matches)\n", + " num_good_matches = min(len(matches), 50)\n", + " if num_good_matches < 10: # Need at least ~6-10 points for a robust estimate\n", + " return np.eye(2, 3, dtype=np.float32)\n", + " \n", + " matches = matches[:num_good_matches]\n", + "\n", + " # Extract location of good matches\n", + " points1 = np.zeros((len(matches), 2), dtype=np.float32)\n", + " points2 = np.zeros((len(matches), 2), dtype=np.float32)\n", + "\n", + " for i, match in enumerate(matches):\n", + " points1[i, :] = keypoints1[match.queryIdx].pt\n", + " points2[i, :] = keypoints2[match.trainIdx].pt\n", + "\n", + " # Find the rigid transformation (Euclidean) using RANSAC\n", + " # cv2.estimateAffinePartial2D is perfect for finding a Euclidean transform\n", + " warp_matrix, _ = cv2.estimateAffinePartial2D(points2, points1, method=cv2.RANSAC)\n", + " \n", + 
def save_warp_dataframe(warp_matrix):
    """Flatten a warp matrix into a dict suitable for one DataFrame row.

    Despite the name, nothing is written to disk; the caller merges the
    returned mapping into its own row of metadata.

    Args:
        warp_matrix: NumPy array with at least two dimensions, e.g. a 2x3 or
            3x3 warp matrix.

    Returns:
        dict: Keys ``"m<row><col>"`` in row-major order mapped to the
        corresponding matrix entries.
    """
    values = warp_matrix.flatten()
    n_rows, n_cols = warp_matrix.shape[0], warp_matrix.shape[1]
    row = {}
    for r in range(n_rows):
        for c in range(n_cols):
            # Row-major position of (r, c) in the flattened array.
            row[f"m{r}{c}"] = values[r * n_cols + c]
    return row
Ensure consistency.\n", + " # Assuming BGR for cvtColor. If RGB, use cv2.COLOR_RGB2GRAY.\n", + " gt_image_uint8 = (gt_image * 255.0).clip(0, 255).astype(np.uint8)\n", + " noisy_image_uint8 = (noisy_image * 255.0).clip(0, 255).astype(np.uint8)\n", + "\n", + " # 3. Convert to grayscale using the faster uint8 versions\n", + " noisy_gray = cv2.cvtColor(noisy_image_uint8, cv2.COLOR_BGR2GRAY)\n", + " gt_gray = cv2.cvtColor(gt_image_uint8, cv2.COLOR_BGR2GRAY)\n", + " h, w = noisy_gray.shape\n", + "\n", + " # 4. --- NEW: Get initial warp matrix from feature matching ---\n", + " # We run this on the full-res grayscale images for better keypoint detection\n", + " warp_matrix = get_initial_warp_matrix(gt_gray, noisy_gray)\n", + "\n", + " # 5. Downsample for ECC refinement\n", + " if downsample_factor > 1:\n", + " # We need to scale the initial guess to the downsampled size\n", + " warp_matrix_scaled = warp_matrix.copy()\n", + " warp_matrix_scaled[0, 2] /= downsample_factor\n", + " warp_matrix_scaled[1, 2] /= downsample_factor\n", + "\n", + " noisy_gray_small = cv2.resize(noisy_gray, (w // downsample_factor, h // downsample_factor), interpolation=cv2.INTER_AREA)\n", + " gt_gray_small = cv2.resize(gt_gray, (w // downsample_factor, h // downsample_factor), interpolation=cv2.INTER_AREA)\n", + " else:\n", + " noisy_gray_small = noisy_gray\n", + " gt_gray_small = gt_gray\n", + " warp_matrix_scaled = warp_matrix\n", + "\n", + " # 6. 
ECC alignment (refinement) using the improved initial guess\n", + " warp_mode = cv2.MOTION_EUCLIDEAN\n", + " criteria = (cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, 200, 1e-5)\n", + " # We provide `warp_matrix_scaled` as the initial guess!\n", + " try:\n", + "\n", + " (cc, warp_matrix_final_scaled) = cv2.findTransformECC(gt_gray_small, noisy_gray_small, warp_matrix_scaled, warp_mode, criteria)\n", + " except cv2.error:\n", + " # If ECC fails, use the initial matrix from feature matching\n", + " cc = -1.0 # Indicate failure or that we used the fallback\n", + " warp_matrix_final_scaled = warp_matrix_scaled\n", + " \n", + " # 7. Scale the final matrix translation back to full resolution\n", + " \n", + " warp_matrix_final = warp_matrix_final_scaled.copy()\n", + " if downsample_factor > 1:\n", + " warp_matrix_final[0, 2] *= downsample_factor\n", + " warp_matrix_final[1, 2] *= downsample_factor\n", + "\n", + " # 8. Warp the original full-resolution FLOAT image\n", + " gt_aligned = cv2.warpAffine(gt_image, warp_matrix_final, (w, h), flags=cv2.INTER_LINEAR + cv2.WARP_INVERSE_MAP)\n", + "\n", + " # 9. 
def as_8bit(x):
    """Convert an image with intensities in [0, 1] to ``uint8`` in [0, 255].

    Values are clipped to [0, 1] before scaling. Without the clip, anything
    slightly out of range (e.g. demosaicing overshoot) wraps around when cast
    to ``uint8`` instead of saturating.

    Args:
        x: Array-like of float intensities.

    Returns:
        ``numpy.ndarray`` of dtype ``uint8``; scaled values are truncated
        toward zero, as before.
    """
    return (np.clip(x, 0.0, 1.0) * 255).astype(np.uint8)
# Align every (noisy, ground-truth) pair, export the JPEGs, and collect one
# metadata row per pair for the alignment CSV.
#
# Fixes relative to the previous version of this cell:
#   * The RAW directory passed to get_align_hybrid is `raw_path` (the folder
#     `file_list` was listed from). The old code passed `path`, which is not
#     assigned in any visible cell of this notebook.
#   * The bare `except:` is narrowed to `Exception` and reports the reason,
#     so a programming error such as the NameError above is visible instead
#     of silently skipping every pair; Ctrl-C is no longer swallowed.
#   * The accumulator no longer shadows the builtin `list`.
alignment_rows = []
for scene_idx, image_pairs in enumerate(pair_file_list.values()):
    print(scene_idx, scene_idx / len(pair_file_list))
    gt_written = False
    for noise, gt in image_pairs:
        try:
            output, noisy_image, gt_aligned, noisy_bayer, gt_image = get_align_hybrid(
                noise, gt, raw_path, downsample_factor=1
            )
            alignment_rows.append(output)
            imageio.imwrite(f"{outpath}/{noise}.jpg", as_8bit(noisy_image), quality=100)
            imageio.imwrite(f"{outpath}/{noise}_bayer.jpg", as_8bit(noisy_bayer), quality=100)
            # The ground truth is shared by all pairs of a scene: write it once.
            if not gt_written:
                imageio.imwrite(f"{outpath}/{gt}.jpg", as_8bit(gt_image), quality=100)
                gt_written = True
        except Exception as exc:
            print(f"skipping {noise}: {exc!r}")

# Persist the per-pair alignment metadata (warp matrix, ISO, statistics).
df = pd.DataFrame(alignment_rows)
df.to_csv(alignment_csv)
"id": "3452667b", + "metadata": {}, + "outputs": [], + "source": [ + "df = pd.read_csv(alignment_csv)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9eb194bc", + "metadata": {}, + "outputs": [], + "source": [ + "row = df.iloc[-2]\n", + "\n", + "# Get Row Matrix\n", + "shape=(2,3)\n", + "cols = [f\"m{i}{j}\" for i in range(shape[0]) for j in range(shape[1])]\n", + "flat = np.array([row.pop(c) for c in cols], dtype=np.float32)\n", + "warp_matrix = flat.reshape(shape)\n", + "warp_matrix\n", + "\n", + "\n", + "noisy_name = row.noisy_image\n", + "gt_name = row.gt_image\n", + "\n", + "with imageio.imopen(f\"{outpath}/{noisy_name}_bayer.jpg\", \"r\") as image_resource:\n", + " bayer_data = image_resource.read()\n", + "\n", + "with imageio.imopen(f\"{outpath}/{noisy_name}.jpg\", \"r\") as image_resource:\n", + " noisy = image_resource.read()\n", + "\n", + "\n", + "with imageio.imopen(f\"{outpath}/{gt_name}.jpg\", \"r\") as image_resource:\n", + " gt_image = image_resource.read()\n", + "\n", + "noisy = noisy/255\n", + "gt_image = gt_image/255\n", + "bayer_data = bayer_data/255\n", + "h, w, _ = noisy.shape\n", + "gt = cv2.warpAffine(gt_image, warp_matrix, (w, h), flags=cv2.INTER_LINEAR + cv2.WARP_INVERSE_MAP)\n", + "\n", + "warp_matrix\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "621a322b", + "metadata": {}, + "outputs": [], + "source": [ + "demosaiced = demosaicing_CFA_Bayer_Malvar2004(bayer_data)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4f8dbf1f", + "metadata": {}, + "outputs": [], + "source": [ + "plt.subplots(2, 3, figsize=(30, 20))\n", + "\n", + "plt.subplot(2,3,1)\n", + "plt.imshow(noisy[1000:1100, 3000:3100])\n", + "\n", + "plt.subplot(2,3,2)\n", + "plt.imshow(demosaiced[1000:1100, 3000:3100])\n", + "\n", + "plt.subplot(2,3, 3)\n", + "plt.imshow(gt[1000:1100, 3000:3100])\n", + "\n", + "\n", + "plt.subplot(2,3, 4)\n", + "plt.imshow(gt[1000:1100, 
def crop_center_square(input_dir, output_dir, crop_size):
    """Save a centred ``crop_size`` x ``crop_size`` crop of each image in a folder.

    Only regular files directly inside ``input_dir`` are considered. Files
    that cannot be opened as images, or that are smaller than the crop, are
    reported and skipped. Problems are printed rather than raised.

    Args:
        input_dir (str): Directory containing the original images.
        output_dir (str): Directory for the cropped images (created if
            missing); each output keeps its original filename.
        crop_size (int): Side length of the square crop; must be even.

    Returns:
        None.
    """
    # Validate the arguments before touching the filesystem.
    if not os.path.isdir(input_dir):
        print(f"Error: Input directory not found at '{input_dir}'")
        return
    if crop_size % 2 != 0:
        print(f"Error: Crop size must be an even number. You provided {crop_size}.")
        return

    os.makedirs(output_dir, exist_ok=True)
    print(f"Output will be saved to: '{output_dir}'")

    processed_count = 0
    for filename in os.listdir(input_dir):
        source_path = os.path.join(input_dir, filename)
        if not os.path.isfile(source_path):
            # Sub-directories are ignored.
            continue
        try:
            with Image.open(source_path) as picture:
                img_w, img_h = picture.size
                if img_w < crop_size or img_h < crop_size:
                    print(f"Skipping '{filename}': smaller than crop size.")
                    continue

                # Centre the window, then snap the origin down to an even
                # coordinate so the 2x2 Bayer phase is preserved.
                x0 = (img_w - crop_size) // 2
                y0 = (img_h - crop_size) // 2
                x0 -= x0 % 2
                y0 -= y0 % 2
                box = (x0, y0, x0 + crop_size, y0 + crop_size)

                pixels = np.array(picture.crop(box))
                imageio.imwrite(os.path.join(output_dir, filename), pixels, quality=95)
                processed_count += 1
        except OSError as e:
            # IOError is an alias of OSError, so this covers both.
            print(f"Could not process '{filename}'. It might not be an image file. Error: {e}")
        except Exception as e:
            print(f"An unexpected error occurred with file '{filename}': {e}")

    print(f"\nProcessing complete. Cropped {processed_count} images.")
def load_config(file_path='config.yaml'):
    """Read a YAML configuration file.

    Args:
        file_path: Path of the YAML file (default ``'config.yaml'``).

    Returns:
        Whatever ``yaml.safe_load`` produces for the file, or ``None`` when
        the file is missing or is not valid YAML. In both failure cases a
        message is printed instead of raising, so callers should check for
        ``None`` before indexing the result.
    """
    config = None
    try:
        with open(file_path, 'r') as handle:
            config = yaml.safe_load(handle)
    except FileNotFoundError:
        print(f"Error: Configuration file not found at {file_path}")
    except yaml.YAMLError as exc:
        print(f"Error parsing YAML file: {exc}")
    return config
rename src/training/{ => losses}/ShadowAwareLoss.py (100%) rename src/{Restorer => training}/losses/__init__.py (100%) rename src/training/{sparse_loop.py => train_loop.py} (84%) diff --git a/1_pretrain_model.ipynb b/1_pretrain_model.ipynb index 1394471..a737238 100644 --- a/1_pretrain_model.ipynb +++ b/1_pretrain_model.ipynb @@ -12,7 +12,10 @@ "from torch.utils.data import DataLoader, random_split\n", "import torch.nn as nn\n", "import torch\n", - "import copy" + "import copy\n", + "import mlflow\n", + "import mlflow.pytorch\n", + "from pathlib import Path" ] }, { @@ -23,87 +26,77 @@ "outputs": [], "source": [ "from src.training.SmallRawDataset import SmallRawDataset\n", - "# from src.Restorer.Cond_NAF import make_model, make_full_model, make_full_model_RGGB\n", - "\n", - "from src.training.ShadowAwareLoss import ShadowAwareLoss\n", + "from src.training.losses.ShadowAwareLoss import ShadowAwareLoss\n", "from src.training.VGGFeatureExtractor import VGGFeatureExtractor\n", - "from src.training.sparse_loop import train_one_epoch, visualize\n", - "from src.training.rggb_loop import train_one_epoch_rggb, visualize\n", - "\n", - "\n", - "from src.training.utils import apply_gamma_torch\n" + "from src.training.train_loop import train_one_epoch, visualize\n", + "from src.training.utils import apply_gamma_torch\n", + "from src.training.load_config import load_config\n", + "from src.Restorer.Cond_NAF import make_full_model_RGGB\n" ] }, { "cell_type": "code", "execution_count": null, - "id": "ba20b866", + "id": "a0464c98", "metadata": {}, "outputs": [], "source": [ - "device= 'mps'\n", - "\n", - "batch_size = 2\n", - "lr = 1e-4 * batch_size / 4\n", - "# lr = 1e-3 * batch_size / 32\n", - "clipping = 1e-2\n", - "\n", - "num_epochs = 75\n", - "val_split = 0.2" + "run_config = load_config()\n", + "dataset_path = Path(run_config['jpeg_output_subdir'])\n", + "align_csv = dataset_path / run_config['align_csv']" ] }, { "cell_type": "code", "execution_count": null, - "id": 
"15f16fa7", + "id": "ba20b866", "metadata": {}, "outputs": [], "source": [ - "dataset = SmallRawDataset('/Volumes/EasyStore/RAWNIND/JPEGs/Cropped_JPEG_high_quality/', 'align.csv', crop_size=256)\n", + "device=run_config['device']\n", "\n", + "batch_size = run_config['batch_size']\n", + "lr = run_config['lr_base'] * batch_size\n", + "clipping = run_config['clipping']\n", "\n", - "# Split dataset into train and val\n", - "val_size = int(len(dataset) * val_split)\n", - "train_size = len(dataset) - val_size\n", - "torch.manual_seed(42) # For reproducibility\n", - "train_dataset, val_dataset = random_split(dataset, [train_size, val_size])\n", - "# Set the validation dataset to use the same crops\n", - "val_dataset = copy.deepcopy(val_dataset)\n", - "val_dataset.dataset.validation = True\n", - "\n", - "train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False, num_workers=0)\n", - "val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0)" + "num_epochs = run_config['num_epochs_pretraining']\n", + "val_split = run_config['val_split']\n", + "crop_size = run_config['crop_size']\n", + "experiment = run_config['mlflow_experiment']\n", + "mlflow.set_experiment(experiment)" ] }, { "cell_type": "code", "execution_count": null, - "id": "3a219ab5", + "id": "a36eaef2", "metadata": {}, "outputs": [], "source": [ - "# model_name = '/Volumes/EasyStore/models/Cond_NAF_original_null.pt'\n", - "# model = make_full_model_RGGB(model_name=None)\n", - "# model = model.to(device)\n", - "\n" + "model_params = run_config['model_params']\n", + "rggb = model_params['rggb']" ] }, { "cell_type": "code", "execution_count": null, - "id": "7669fd6d", + "id": "15f16fa7", "metadata": {}, "outputs": [], "source": [ - "# # model_name = '/Volumes/EasyStore/models/Restorer_train_vgg_relu_300_full.pt'\n", - "# # model = make_full_model(model_name=model_name)\n", - "# # model = model.to(device)\n", + "dataset = SmallRawDataset(dataset_path, align_csv, 
crop_size=crop_size)\n", "\n", - "# model_name = '/Volumes/EasyStore/models/Restorer_train_vgg_relu_300_full_RGGB.pt'\n", - "# model_name = '/Volumes/EasyStore/models/Restorer_train_vgg_relu_0_full_RGGB.pt'\n", + "# Split dataset into train and val\n", + "val_size = int(len(dataset) * val_split)\n", + "train_size = len(dataset) - val_size\n", + "torch.manual_seed(42) # For reproducibility\n", + "train_dataset, val_dataset = random_split(dataset, [train_size, val_size])\n", + "# Set the validation dataset to use the same crops\n", + "val_dataset = copy.deepcopy(val_dataset)\n", + "val_dataset.dataset.validation = True\n", "\n", - "# model = make_full_model_RGGB(model_name=model_name)\n", - "# model = model.to(device)\n" + "train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False, num_workers=0)\n", + "val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0)" ] }, { @@ -113,44 +106,10 @@ "metadata": {}, "outputs": [], "source": [ - "from src.Restorer.Cond_NAF import make_model, make_full_model, make_full_model_RGGB\n", - "\n", - "model_name = '/Volumes/EasyStore/models/Restorer_train_vgg_relu_300_full_RGGB.pt'\n", - "model_name = '/Volumes/EasyStore/models/Restorer_train_vgg_relu_0_full_RGGB.pt'\n", - "\n", - "model = make_full_model_RGGB(model_name=model_name)\n", + "model = make_full_model_RGGB(model_params, model_name=None)\n", "model = model.to(device)\n", "\n", - "\n", - "from src.Restorer.Cond_CHASPA import make_full_model_RGGB\n", - "\n", - "\n", - "model_name = '/Volumes/EasyStore/models/Restorer_train_vgg_relu_0_full_RGGB_CHASPA.pt'\n", - "\n", - "model = make_full_model_RGGB(model_name=model_name)\n", - "model = model.to(device)\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "8012a124", - "metadata": {}, - "outputs": [], - "source": [ - "# model_name = '/Volumes/EasyStore/models/Restorer_train_vgg_relu_300_full.pt'\n", - "# model = make_full_model(model_name=model_name)\n", - 
"# # model = model.to(device)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "1a666123", - "metadata": {}, - "outputs": [], - "source": [ - "model" + "params = {**run_config, **model_params}" ] }, { @@ -169,12 +128,12 @@ "vfe = vfe.to(device)\n", "\n", "loss_fn = ShadowAwareLoss(\n", - " alpha=0.2,\n", - " beta=5.0,\n", - " l1_weight=0.16,\n", - " ssim_weight=0.84,\n", - " tv_weight=0.0,\n", - " vgg_loss_weight=0,\n", + " alpha=run_config['alpha'],\n", + " beta=run_config['beta'],\n", + " l1_weight=run_config['l1_weight'],\n", + " ssim_weight=run_config['ssim_weight'],\n", + " tv_weight=run_config['tv_weight'],\n", + " vgg_loss_weight=run_config['vgg_loss_weight'],\n", " apply_gamma_fn=apply_gamma_torch,\n", " vgg_feature_extractor=vfe,\n", " device=device,\n", @@ -188,8 +147,27 @@ "metadata": {}, "outputs": [], "source": [ - "for epoch in range(num_epochs):\n", - " train_one_epoch(epoch, model, optimizer, model_name, train_loader, device, loss_fn, clipping, log_interval = 10, sleep=0.0, rggb=True)" + "with mlflow.start_run(run_name=run_config['run_name']) as run:\n", + "\n", + " mlflow.log_params(params)\n", + " for epoch in range(num_epochs):\n", + " train_one_epoch(epoch, model, optimizer, train_loader, device, loss_fn, clipping, \n", + " log_interval = 10, sleep=0.0, rggb=rggb, max_batches=2)\n", + " \n", + " mlflow.pytorch.log_model(\n", + " pytorch_model=model,\n", + " name=run_config['run_path'],\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5a775d31", + "metadata": {}, + "outputs": [], + "source": [ + "run.info.run_id" ] } ], diff --git a/1_validate_model.ipynb b/1_validate_model.ipynb new file mode 100644 index 0000000..a909f15 --- /dev/null +++ b/1_validate_model.ipynb @@ -0,0 +1,179 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "f6351e77", + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "import matplotlib.pyplot as plt\n", + 
"from torch.utils.data import DataLoader, random_split\n", + "import torch.nn as nn\n", + "import torch\n", + "import copy\n", + "import mlflow\n", + "from pathlib import Path" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2043dc8e", + "metadata": {}, + "outputs": [], + "source": [ + "from src.training.SmallRawDataset import SmallRawDataset\n", + "\n", + "from src.training.losses.ShadowAwareLoss import ShadowAwareLoss\n", + "from src.training.VGGFeatureExtractor import VGGFeatureExtractor\n", + "from src.training.utils import apply_gamma_torch\n", + "from src.training.train_loop import visualize\n", + "from src.training.load_config import load_config\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b14af214", + "metadata": {}, + "outputs": [], + "source": [ + "run_config = load_config()\n", + "dataset_path = Path(run_config['jpeg_output_subdir'])\n", + "align_csv = dataset_path / run_config['align_csv']\n", + "device=run_config['device']\n", + "\n", + "batch_size = run_config['batch_size']\n", + "lr = run_config['lr_base'] * batch_size\n", + "clipping = run_config['clipping']\n", + "\n", + "num_epochs = run_config['num_epochs_pretraining']\n", + "val_split = run_config['val_split']\n", + "crop_size = run_config['crop_size']\n", + "experiment = run_config['mlflow_experiment']\n", + "model_params = run_config['model_params']\n", + "rggb = model_params['rggb']" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ba20b866", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "RUN_ID = \"66f11f4639f24e9ea75b4c953147be15\" \n", + "ARTIFACT_PATH = run_config['run_path']\n", + "\n", + "model_uri = f\"runs:/{RUN_ID}/{ARTIFACT_PATH}\"\n", + "\n", + "try:\n", + " model = mlflow.pytorch.load_model(model_uri)\n", + " model.eval()\n", + " print(f\"Model successfully loaded from MLflow URI: {model_uri}\")\n", + " \n", + "\n", + "except Exception as e:\n", + " print(f\"Error loading model from MLflow: 
{e}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "15f16fa7", + "metadata": {}, + "outputs": [], + "source": [ + "dataset = SmallRawDataset(dataset_path, align_csv, crop_size=crop_size)\n", + "\n", + "# Split dataset into train and val\n", + "val_size = int(len(dataset) * val_split)\n", + "train_size = len(dataset) - val_size\n", + "torch.manual_seed(42) # For reproducibility\n", + "train_dataset, val_dataset = random_split(dataset, [train_size, val_size])\n", + "# Set the validation dataset to use the same crops\n", + "val_dataset = copy.deepcopy(val_dataset)\n", + "val_dataset.dataset.validation = True\n", + "\n", + "train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False, num_workers=0)\n", + "val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6af0f3a2", + "metadata": {}, + "outputs": [], + "source": [ + "optimizer = torch.optim.Adam(model.parameters(), lr=lr)\n", + "\n", + "vfe = VGGFeatureExtractor(config=((1, 64), (1, 128), (1, 256), (1, 512), (1, 512),), \n", + " feature_layers=[14], \n", + " activation=nn.ReLU\n", + " )\n", + "vfe = vfe.to(device)\n", + "\n", + "loss_fn = ShadowAwareLoss(\n", + " alpha=run_config['alpha'],\n", + " beta=run_config['beta'],\n", + " l1_weight=run_config['l1_weight'],\n", + " ssim_weight=run_config['ssim_weight'],\n", + " tv_weight=run_config['tv_weight'],\n", + " vgg_loss_weight=run_config['vgg_loss_weight'],\n", + " apply_gamma_fn=apply_gamma_torch,\n", + " vgg_feature_extractor=vfe,\n", + " device=device,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ae01182c", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "subset_indices = np.array(val_dataset.indices) # indices in the original dataset\n", + "mask = val_dataset.dataset.df.iso.values[subset_indices] == 65535\n", + "matching_indices_in_subset = 
np.nonzero(mask)[0]\n", + "matching_indices_in_subset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2234a52d", + "metadata": {}, + "outputs": [], + "source": [ + "visualize(matching_indices_in_subset, model, val_dataset, device, loss_fn, rggb=rggb)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "OnSight", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.11" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/src/Restorer/Cond_CHASPA.py b/src/Restorer/Cond_CHASPA.py index 9371a48..f7bac68 100644 --- a/src/Restorer/Cond_CHASPA.py +++ b/src/Restorer/Cond_CHASPA.py @@ -324,69 +324,11 @@ def check_image_size(self, x): x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h)) return x - -class ModelWrapper(nn.Module): - def __init__(self): - super().__init__() - self.model = Restorer( - chans = [32, 128, 256, 512, 1024], - enc_blk_nums = [1,1,3,4], - middle_blk_num = 6, - dec_blk_nums = [2, 2, 1, 1], - cond_input = 1, - in_channels = 3, - out_channels = 3, - ) - - def forward(self, x, cond, residual): - output = self.model(x, cond) - return residual + output - -def make_model(model_name = '/Volumes/EasyStore/models/Cond_NAF_variable_layers_cca_merge_unet_sparse_ssim_real_raw.pt'): - model = ModelWrapper() - if not model_name is None: - state_dict = torch.load(model_name, map_location="cpu") - model.load_state_dict(state_dict) - return model - -class ModelWrapperFull(nn.Module): - def __init__(self): - super().__init__() - self.model = Restorer( - chans = [64, 128, 256, 512, 1024], - enc_blk_nums = [1,1,3,4], - middle_blk_num = 6, - dec_blk_nums = [2, 2, 1, 1], - cond_input = 1, - in_channels = 3, - out_channels = 3, - ) - - def forward(self, x, cond, residual): - output = self.model(x, 
cond) - return residual + output - - -def make_full_model(model_name = '/Volumes/EasyStore/models/Cond_NAF_variable_layers_cca_merge_unet_sparse_ssim_real_raw_full.pt'): - model = ModelWrapperFull() - if not model_name is None: - state_dict = torch.load(model_name, map_location="cpu") - model.load_state_dict(state_dict) - return model - - class ModelWrapperFullRGGB(nn.Module): - def __init__(self): + def __init__(self, **kwargs): super().__init__() self.model = Restorer( - chans = [32, 64, 128, 256, 256, 256], - enc_blk_nums = [2,2,2,3,4], - middle_blk_num = 12, - dec_blk_nums = [2, 2, 2, 2, 2], - cond_input = 1, - in_channels = 4, - out_channels = 3, - rggb=True, + **kwargs ) def forward(self, x, cond, residual): @@ -395,62 +337,18 @@ def forward(self, x, cond, residual): def make_full_model_RGGB(model_name = '/Volumes/EasyStore/models/Cond_NAF_variable_layers_cca_merge_unet_sparse_ssim_real_raw_full_RGGB.pt'): - model = ModelWrapperFullRGGB() + params = {"chans" : [32, 64, 128, 256, 256, 256], + "enc_blk_nums" : [2,2,2,3,4], + "middle_blk_num" : 12, + "dec_blk_nums" : [2, 2, 2, 2, 2], + "cond_input" : 1, + "in_channels" : 4, + "out_channels" : 3, + "rggb": True, + } + model = ModelWrapperFullRGGB(**params) if not model_name is None: state_dict = torch.load(model_name, map_location="cpu") model.load_state_dict(state_dict) - return model - + return model, params -class ModelWrapperResidual(nn.Module): - def __init__(self): - super().__init__() - self.model = Restorer( - chans = [32, 64, 128, 256, 256, 256], - enc_blk_nums = [2,2,2,3,4], - middle_blk_num = 12, - dec_blk_nums = [2, 2, 2, 2, 2], - cond_input = 1, - in_channels = 6, - out_channels = 3, - rggb=False, - ) - - def forward(self, x, cond, residual): - output = self.model(x, cond) - return residual + output - -def make_residual_model(model_name = None): - model = ModelWrapperResidual() - if not model_name is None: - state_dict = torch.load(model_name, map_location="cpu") - model.load_state_dict(state_dict) - 
return model - - - - -class ModelWrapperDeep(nn.Module): - def __init__(self): - super().__init__() - self.model = Restorer( - chans = [32, 64, 128, 256, 512, 1024], - enc_blk_nums = [2,2,2,3,4], - middle_blk_num = 6, - dec_blk_nums = [2, 2, 2, 2, 2], - cond_input = 1, - in_channels = 3, - out_channels = 3, - ) - - def forward(self, x, cond, residual): - output = self.model(x, cond) - return residual + output - - -def make_deep_model(model_name = '/Volumes/EasyStore/models/Cond_NAF_variable_layers_cca_merge_unet_sparse_ssim_real_raw_deep.pt'): - model = ModelWrapperDeep() - if not model_name is None: - state_dict = torch.load(model_name, map_location="cpu") - model.load_state_dict(state_dict) - return model \ No newline at end of file diff --git a/src/Restorer/Cond_NAF.py b/src/Restorer/Cond_NAF.py index 1372378..9fa29dc 100644 --- a/src/Restorer/Cond_NAF.py +++ b/src/Restorer/Cond_NAF.py @@ -361,123 +361,11 @@ def check_image_size(self, x): x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h)) return x - class ModelWrapper(nn.Module): - def __init__(self): - super().__init__() - self.model = Restorer( - chans = [32, 128, 256, 512, 1024], - enc_blk_nums = [1,1,3,4], - middle_blk_num = 6, - dec_blk_nums = [2, 2, 1, 1], - cond_input = 1, - in_channels = 3, - out_channels = 3, - ) - - def forward(self, x, cond, residual): - output = self.model(x, cond) - return residual + output - -def make_model(model_name = '/Volumes/EasyStore/models/Cond_NAF_variable_layers_cca_merge_unet_sparse_ssim_real_raw.pt'): - model = ModelWrapper() - if not model_name is None: - state_dict = torch.load(model_name, map_location="cpu") - model.load_state_dict(state_dict) - return model - -class ModelWrapperFull(nn.Module): - def __init__(self): - super().__init__() - self.model = Restorer( - chans = [64, 128, 256, 512, 1024], - enc_blk_nums = [1,1,3,4], - middle_blk_num = 6, - dec_blk_nums = [2, 2, 1, 1], - cond_input = 1, - in_channels = 3, - out_channels = 3, - ) - - def forward(self, x, cond, 
residual): - output = self.model(x, cond) - return residual + output - - -def make_full_model(model_name = '/Volumes/EasyStore/models/Cond_NAF_variable_layers_cca_merge_unet_sparse_ssim_real_raw_full.pt'): - model = ModelWrapperFull() - if not model_name is None: - state_dict = torch.load(model_name, map_location="cpu") - model.load_state_dict(state_dict) - return model - - -class ModelWrapperFullRGGB(nn.Module): - def __init__(self): - super().__init__() - self.model = Restorer( - chans = [32, 64, 128, 256, 256, 256], - enc_blk_nums = [2,2,2,3,4], - middle_blk_num = 12, - dec_blk_nums = [2, 2, 2, 2, 2], - cond_input = 1, - in_channels = 4, - out_channels = 3, - rggb=True, - ) - - def forward(self, x, cond, residual): - output = self.model(x, cond) - return residual + output - - -def make_full_model_RGGB(model_name = '/Volumes/EasyStore/models/Cond_NAF_variable_layers_cca_merge_unet_sparse_ssim_real_raw_full_RGGB.pt'): - model = ModelWrapperFullRGGB() - if not model_name is None: - state_dict = torch.load(model_name, map_location="cpu") - model.load_state_dict(state_dict) - return model - - -class ModelWrapperResidual(nn.Module): - def __init__(self): - super().__init__() - self.model = Restorer( - chans = [32, 64, 128, 256, 256, 256], - enc_blk_nums = [2,2,2,3,4], - middle_blk_num = 12, - dec_blk_nums = [2, 2, 2, 2, 2], - cond_input = 1, - in_channels = 6, - out_channels = 3, - rggb=False, - ) - - def forward(self, x, cond, residual): - output = self.model(x, cond) - return residual + output - -def make_residual_model(model_name = None): - model = ModelWrapperResidual() - if not model_name is None: - state_dict = torch.load(model_name, map_location="cpu") - model.load_state_dict(state_dict) - return model - - - - -class ModelWrapperDeep(nn.Module): - def __init__(self): + def __init__(self, **kwargs): super().__init__() self.model = Restorer( - chans = [32, 64, 128, 256, 512, 1024], - enc_blk_nums = [2,2,2,3,4], - middle_blk_num = 6, - dec_blk_nums = [2, 2, 2, 2, 
2], - cond_input = 1, - in_channels = 3, - out_channels = 3, + **kwargs ) def forward(self, x, cond, residual): @@ -485,8 +373,8 @@ def forward(self, x, cond, residual): return residual + output -def make_deep_model(model_name = '/Volumes/EasyStore/models/Cond_NAF_variable_layers_cca_merge_unet_sparse_ssim_real_raw_deep.pt'): - model = ModelWrapperDeep() +def make_full_model_RGGB(params, model_name=None): + model = ModelWrapper(**params) if not model_name is None: state_dict = torch.load(model_name, map_location="cpu") model.load_state_dict(state_dict) diff --git a/src/training/SmallRawDataset.py b/src/training/SmallRawDataset.py index ff3d013..6bd8919 100644 --- a/src/training/SmallRawDataset.py +++ b/src/training/SmallRawDataset.py @@ -32,10 +32,10 @@ def __len__(self): def __getitem__(self, idx): row = self.df.iloc[idx] # Load images - with imageio.imopen(f"{row.bayer_path}", "r") as image_resource: + with imageio.imopen(self.path / f"{row.noisy_image}_bayer.jpg", "r") as image_resource: bayer_data = image_resource.read() - with imageio.imopen(f"{row.gt_path}", "r") as image_resource: + with imageio.imopen(self.path / f"{row.gt_image}.jpg", "r") as image_resource: gt_image = image_resource.read() gt_image = gt_image/255 bayer_data = bayer_data/255 diff --git a/src/training/align_images.py b/src/training/align_images.py index ddc08ec..b63f351 100644 --- a/src/training/align_images.py +++ b/src/training/align_images.py @@ -107,7 +107,7 @@ def safe_gray(img): # flatten warp matrix for CSV for i in range(2): for j in range(3): - metrics[f"M{i}{j}"] = float(M_total[i, j]) + metrics[f"m{i}{j}"] = float(M_total[i, j]) return aligned, M_total, metrics @@ -116,12 +116,12 @@ def safe_gray(img): def apply_alignment(img, warp_params, interpolation=cv2.INTER_LINEAR): """ Applies a previously estimated affine warp to an image. - warp_params: dict with keys M00..M12 or a 2x3 numpy array. + warp_params: dict with keys m00..m12 or a 2x3 numpy array. 
""" if isinstance(warp_params, dict): M = np.array([ - [warp_params["M00"], warp_params["M01"], warp_params["M02"]], - [warp_params["M10"], warp_params["M11"], warp_params["M12"]], + [warp_params["m00"], warp_params["m01"], warp_params["m02"]], + [warp_params["m10"], warp_params["m11"], warp_params["m12"]], ], dtype=np.float32) else: M = np.array(warp_params, dtype=np.float32) @@ -176,6 +176,6 @@ def __getitem__(self, idx): aligned, matrix, metrics = align_clean_to_noisy(gt_image, demosaiced_noisy, refine=False) metrics['iso'] = row.iso metrics['std'] = (demosaiced_noisy.astype(int) - aligned.astype(int)).std() - metrics['bayer_path'] = bayer_path - metrics['gt_path'] = gt_path + metrics['noisy_image'] = bayer_path + metrics['gt_image'] = gt_path return gt_image, demosaiced_noisy, aligned, metrics \ No newline at end of file diff --git a/src/Restorer/losses/MultiScaleLoss.py b/src/training/losses/MultiScaleLoss.py similarity index 100% rename from src/Restorer/losses/MultiScaleLoss.py rename to src/training/losses/MultiScaleLoss.py diff --git a/src/training/ShadowAwareLoss.py b/src/training/losses/ShadowAwareLoss.py similarity index 100% rename from src/training/ShadowAwareLoss.py rename to src/training/losses/ShadowAwareLoss.py diff --git a/src/Restorer/losses/__init__.py b/src/training/losses/__init__.py similarity index 100% rename from src/Restorer/losses/__init__.py rename to src/training/losses/__init__.py diff --git a/src/training/sparse_loop.py b/src/training/train_loop.py similarity index 84% rename from src/training/sparse_loop.py rename to src/training/train_loop.py index f64bba7..3b3296b 100644 --- a/src/training/sparse_loop.py +++ b/src/training/train_loop.py @@ -4,6 +4,7 @@ import torch import torch.nn as nn from src.training.utils import apply_gamma_torch +import mlflow def make_conditioning(conditioning, device): B = conditioning.shape[0] @@ -12,29 +13,24 @@ def make_conditioning(conditioning, device): return conditioning_extended -def 
train_one_epoch(epoch, _model, _optimizer, _outname, _loader, _device, _loss_func, _clipping, log_interval = 10, sleep=0.2, rggb=False): +def train_one_epoch(epoch, _model, _optimizer, _loader, _device, _loss_func, _clipping, + log_interval = 10, sleep=0.0, rggb=False, + max_batches=0): _model.train() - total_loss, n_images, total_final_image_loss = 0.0, 0, 0.0 + total_loss, n_images, total_l1_loss = 0.0, 0, 0.0 start = perf_counter() pbar = tqdm(enumerate(_loader), total=len(_loader), desc=f"Train Epoch {epoch}") - # scaler = torch.amp.GradScaler(device_type="mps") # create once, outside training loop - for batch_idx, (output) in pbar: - # for output in train_loader: noisy = output['noisy'].float().to(_device) conditioning = output['conditioning'].float().to(_device) gt = output['aligned'].float().to(_device) input = output['sparse'].float().to(_device) if rggb: input = output['rggb'].float().to(_device) - conditioning = make_conditioning(conditioning, _device) - # with torch.autocast(device_type="mps", dtype=torch.bfloat16): - _optimizer.zero_grad(set_to_none=True) - # with torch.autocast(device_type="mps", dtype=torch.float16): pred = _model(input, conditioning, noisy) loss = _loss_func(pred, gt) @@ -48,22 +44,25 @@ def train_one_epoch(epoch, _model, _optimizer, _outname, _loader, _device, _loss # Testing final image quality final_image_loss = float(nn.functional.l1_loss(pred, gt).detach().cpu()) - total_final_image_loss += final_image_loss + total_l1_loss += final_image_loss del loss, pred, final_image_loss torch.mps.empty_cache() if (batch_idx + 1) % log_interval == 0: pbar.set_postfix({"loss": f"{total_loss/n_images:.4f}"}) + if (max_batches > 0) and (batch_idx+1 > max_batches): break time.sleep(sleep) - torch.save(_model.state_dict(), _outname) - + train_time = perf_counter()-start print(f"[Epoch {epoch}] " f"Train loss: {total_loss/n_images:.6f} " - f"Final image val loss: {total_final_image_loss/n_images:.6f} " - f"Time: {perf_counter()-start:.1f}s " + 
f"L1 loss: {total_l1_loss/n_images:.6f} " + f"Time: {train_time:.1f}s " f"Images: {n_images}") + mlflow.log_metric("train_loss", total_loss/n_images, step=epoch) + mlflow.log_metric("l1_loss", total_l1_loss/n_images, step=epoch) + mlflow.log_metric("epoch_duration_s", train_time, step=epoch) return total_loss / max(1, n_images), perf_counter()-start @@ -76,9 +75,7 @@ def visualize(idxs, _model, dataset, _device, _loss_func, rggb=False): total_loss, n_images, total_final_image_loss = 0.0, 0, 0.0 start = perf_counter() - for idx in idxs: - # for output in train_loader: row = dataset[idx] noisy = row['noisy'].unsqueeze(0).float().to(_device) conditioning = row['conditioning'].float().unsqueeze(0).to(_device) From 09cfeffc2a4b6de88144f0ecb0a761fb7139e9cf Mon Sep 17 00:00:00 2001 From: Ryan Mueller Date: Thu, 16 Oct 2025 20:06:57 -0400 Subject: [PATCH 06/56] Update to readme --- README.md | 123 +++++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 121 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 21cf958..db6bb96 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,122 @@ -# WIP +# Restorer -This is a WIP repo for testing conditioning a unet. More instructions to follow. \ No newline at end of file +## Overview + +This branch contains experimental code for a lightweight, local training workflow to aid in the development of models for **Raw Refinery**. The goal is to enable faster model iteration on low-powered local computers with limited SSD space when access to large GPUs and SSDs are limited. + +While functional, it remains a **work in progress**, primarily intended as a proof-of-concept for faster local experimentation. + +Because the models are designed to be computationally efficient, the main bottleneck in this workflow is **disk I/O**. Raw image files are high-resolution 12-bit images, too large for the laptop’s internal SSD but too slow to stream from an external HDD. 
To mitigate this, we generate a smaller, compressed dataset suitable for quick pretraining. + +--- + +## Producing the Smaller Dataset + +We pretrain using a compressed **8-bit version** of the noisy raw sensor data and their corresponding **preprocessed (demosaiced)** ground-truth images. + +After testing various formats, **JPEG** proved to be the most convenient and performant: + +* Faster to read than PNG or NumPy-based formats +* Smaller file sizes even at high quality settings +* Sufficiently preserves high-frequency detail in the Bayer sensor data + +*(A possible future improvement would be to store each of the four Bayer channels as separate JPEGs and merge them at runtime, but this implementation opts for simplicity.)* + +### Relevant Files + +* **`download_rawnind.sh`** + Downloads the [RawNIND dataset](https://doi.org/10.14428/DVN/DEQCIM) file-by-file, as it’s too large to fetch directly. You can place it in any directory where you wish the dataset to be stored. + +* **`config.yaml`** + Contains global configuration parameters including file paths, training hyperparameters, and general workflow settings. + +* **`0_produce_small_dataset.ipynb`** + Two-part notebook: + + 1. Iterates through raw images, aligns a demosaiced ground-truth image to each noisy image, and saves aligned pairs as JPEGs. + + * Alignment first uses a **feature-based method** for coarse alignment. + * Followed by an **ECC alignment** for sub-pixel precision. + * The **correlation coefficient** is saved to help identify poorly aligned images. + * Manual inspection is still recommended—alignment quality is critical for effective model training. + 2. Optionally crops the JPEGs to reduce storage requirements and speed up data loading. + +* **`0_align_images.ipynb`** + Re-runs image alignment. Useful for experimenting with alternative alignment techniques or metrics. Not required if the previous notebook has already aligned the dataset. 
+ +--- + +## Pretraining + +With the small dataset prepared, we can begin **local pretraining**. + +Training hyperparameters are stored in `config.yaml`. Typical settings: + +* **Patch size:** 256×256 +* **Batch size:** 2 (for low-memory environments) +* **Optimizer:** Adam +* **Learning rate:** constant (can be scaled for larger GPUs) +* **Epochs:** 75 (sufficient for baseline performance); 150 preferred (~<24 hrs training time) + +MLflow is used to track runs, hyperparameters, and metrics. + +The dataset is split **80/20** into training and validation subsets. +Regularization methods like L2 weight decay, dropout, or strong augmentations often **harm** reconstruction performance; **random cropping** is typically sufficient. + +### Relevant Files + +* **`1_pretrain_model.ipynb`** + Main training notebook. Loads the model, initializes the optimizer, and runs the training loop. + MLflow integration records hyperparameters and results. + *(Validation is not yet integrated directly in the training loop.)* + **To-do:** Add model checkpointing. + +* **`1_validate_model.ipynb`** + Validates trained models. + Separated from the main loop to allow for manual visual inspection of images with different noise levels, which is often more informative than numerical loss metrics. + **To-do:** Save validation artifacts (sample images, metrics) to the MLflow run. + +* **`src/Restorer/Cond_NAF.py`** + Defines the neural network model. + +* **`src/training/train_loop.py`** + Implements the core training and validation loops. + +* **`src/training/losses/ShadowAwareLoss.py`** + Custom loss combining: + + * L1 loss + * Multi-scale SSIM loss + * Perceptual loss (VGG-like) + Designed to emphasize realistic and visually appealing reconstructions. + +--- + +## Fine-Tuning + +Once pretraining converges, we fine-tune the model on **real raw images**. +This step restores high-frequency detail lost during 8-bit compression. 
+ +Fine-tuning uses similar hyperparameters to pretraining, but introduces **Cosine annealing** for smoother learning rate decay. + +This portion of the workflow is still in progress and will be added shortly. + +--- + +## Deployment + +To deploy a trained model, we **trace** it into a TorchScript module for seamless integration with the **Raw Refinery** application. + +This is handled in the **`3_make_script.ipynb`** notebook. + +--- + +## Acknowledgments + +> Brummer, Benoit; De Vleeschouwer, Christophe. (2025). +> *Raw Natural Image Noise Dataset.* +> [https://doi.org/10.14428/DVN/DEQCIM](https://doi.org/10.14428/DVN/DEQCIM), Open Data @ UCLouvain, V1. + +> Chen, Liangyu; Chu, Xiaojie; Zhang, Xiangyu; Chen, Jianhao. (2022). +> *NAFNet: Simple Baselines for Image Restoration.* +> [https://doi.org/10.48550/arXiv.2208.04677](https://doi.org/10.48550/arXiv.2208.04677), arXiv, V1. \ No newline at end of file From 5b1793fbd3f869f09b231e89e22146c44e85da73 Mon Sep 17 00:00:00 2001 From: Ryan Mueller Date: Thu, 16 Oct 2025 20:08:53 -0400 Subject: [PATCH 07/56] Updates to models --- src/Restorer/CombinedPerceptualLoss.py | 203 ------------------------- src/Restorer/Cond_NAF.py | 45 ++---- src/Restorer/Cond_NAF_ps.py | 12 -- 3 files changed, 11 insertions(+), 249 deletions(-) delete mode 100644 src/Restorer/CombinedPerceptualLoss.py diff --git a/src/Restorer/CombinedPerceptualLoss.py b/src/Restorer/CombinedPerceptualLoss.py deleted file mode 100644 index 75b3b07..0000000 --- a/src/Restorer/CombinedPerceptualLoss.py +++ /dev/null @@ -1,203 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -import torchvision -from torchvision import models -from pytorch_msssim import ms_ssim -from pytorch_msssim.ssim import _ssim, _fspecial_gauss_1d -import lpips -import torch_dct as dct_2d -from kornia.color import rgb_to_ycbcr - - -# Losses - - -class FrequencyLoss(nn.Module): - def forward(self, output, target): - return F.l1_loss(dct_2d.dct_2d(output), dct_2d.dct_2d(target)) 
- - -class ChromaLoss(nn.Module): - def __init__(self, loss_func=F.l1_loss): - super().__init__() - self.loss_func = loss_func - - def forward(self, output, target): - return self.loss_func(rgb_to_ycbcr(output)[:, 1:], rgb_to_ycbcr(target)[:, 1:]) - - -class LumaLoss(nn.Module): - def __init__(self, loss_func=F.l1_loss): - super().__init__() - self.loss_func = loss_func - - def forward(self, output, target): - return self.loss_func(rgb_to_ycbcr(output)[:, :1], rgb_to_ycbcr(target)[:, :1]) - - -class VGGPerceptualLoss(nn.Module): - def __init__( - self, layers=[0, 5, 10, 19, 28], weights=[1.0] * 5, apply_gamma_curve=False - ): - super().__init__() - vgg = models.vgg16(weights=True).features[: max(layers) + 1] - for param in vgg.parameters(): - param.requires_grad = False - self.vgg = vgg.eval() - self.layers = layers - self.weights = weights - self.criterion = nn.L1Loss() - self.apply_gamma_curve = apply_gamma_curve - - def forward(self, pred, target): - if self.apply_gamma_curve: - pred = apply_gamma(pred) - target = apply_gamma(target) - loss = 0.0 - x = pred - y = target - wdix = 0 - for idx, layer in enumerate(self.vgg): - x = layer(x) - y = layer(y) - if idx in self.layers: - loss += self.weights[wdix] * self.criterion(x, y) - wdix += 1 - return loss - - -def gram_matrix(features): - """Compute the Gram matrix from feature maps.""" - B, C, H, W = features.size() - features = features.view(B, C, H * W) # Flatten spatial dimensions - gram = torch.bmm(features, features.transpose(1, 2)) # Batch matrix multiplication - gram = gram / (C * H * W) # Normalize - return gram - - -class StyleLoss(VGGPerceptualLoss): - def forward(self, pred, target): - if self.apply_gamma_curve: - pred = apply_gamma(pred) - target = apply_gamma(target) - loss = 0.0 - x = pred - y = target - wdix = 0 - for idx, layer in enumerate(self.vgg): - x = layer(x) - y = layer(y) - if idx in self.layers: - x_gram = gram_matrix(x) - y_gram = gram_matrix(y) - loss += self.weights[wdix] * 
self.criterion(x_gram, y_gram) - wdix += 1 - return loss - - -def transform_vgg(input): - mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1) - std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1) - return (input - mean) / std - - -def apply_gamma(_tensor): - tensor = _tensor.clone() - img_mask = tensor > 0.0031308 - tensor[img_mask] = 1.055 * torch.pow(tensor[img_mask], 1.0 / 2.4) - 0.055 - tensor[~img_mask] *= 12.92 - return tensor - - -def sharpness_loss(pred, target): - pred_blur = torchvision.transforms.functional.gaussian_blur(pred, 11, sigma=1.0) - target_blur = torchvision.transforms.functional.gaussian_blur(target, 11, sigma=1.0) - return F.l1_loss(pred - pred_blur, target - target_blur) - - -def conservative_l1(pred, inp, gt, sout=1, sin=1 / 10): - same_side = (pred - gt) * (inp - gt) > 0 - between = (pred - gt) * (pred - inp) < 0 - diff = (pred - gt).abs() - diff[between] *= sin - diff[~same_side] *= sout - diff[same_side] *= sout - diff[same_side * ~between] += (gt - inp).abs()[same_side * ~between] * (sin - 1) - return diff.mean() - - -def structural_loss(pred, gt): - win_size, win_sigma = 11, 1.5 - win = _fspecial_gauss_1d(win_size, win_sigma) - win = win.repeat([pred.shape[1]] + [1] * (len(pred.shape) - 1)) - _, cs = _ssim(pred, gt, data_range=1.0, win=win) - return 1 - cs.mean() - - -def gradient_loss(pred, gt): - pred_dx = pred[:, :, :, 1:] - pred[:, :, :, :-1] - gt_dx = gt[:, :, :, 1:] - gt[:, :, :, :-1] - pred_dy = pred[:, :, 1:, :] - pred[:, :, :-1, :] - gt_dy = gt[:, :, 1:, :] - gt[:, :, :-1, :] - return torch.mean(torch.abs(pred_dx - gt_dx)) + torch.mean( - torch.abs(pred_dy - gt_dy) - ) - - -# ==== Combined Loss Class ==== - - -class CombinedPerceptualLoss(nn.Module): - def __init__(self, **lambdas): - super().__init__() - self.loss_weights = lambdas - - # Register loss components - self.loss_modules = { - "l1": nn.L1Loss(), - "mse": nn.MSELoss(), - "ssim": lambda p, t: 1 - ms_ssim(p, t, data_range=1.0, size_average=True), - 
"ssim_coarse": lambda p, t: 1 - - ms_ssim( - p, - t, - data_range=1.0, - size_average=True, - weights=[0.1, 0.2, 0.448028, 0.352898, 0.199074], - ), - "ssim_fine": lambda p, t: 1 - - ms_ssim(p, t, data_range=1.0, size_average=True, weights=[1, 0, 0, 0, 0]), - "vgg": VGGPerceptualLoss(apply_gamma_curve=True), - "style": StyleLoss( - layers=[0, 5, 10, 17], - weights=[1.0, 1.0, 0.5, 0.25], - apply_gamma_curve=True, - ), - "frequency": FrequencyLoss(), - "chroma": ChromaLoss(), - "luma": LumaLoss(), - "luma_mse": LumaLoss(loss_func=F.mse_loss), - "luma_struct": LumaLoss(loss_func=structural_loss), - "lpips": lpips.LPIPS(net="vgg"), - "sharpness_loss": sharpness_loss, - } - - def forward(self, pred, target): - loss = 0.0 - for name, weight in self.loss_weights.items(): - if weight <= 0: - continue - if name == "invariant_l1": - mean_t = target.mean(dim=(1, 2, 3), keepdim=True) - mean_p = pred.mean(dim=(1, 2, 3), keepdim=True) - val = self.loss_modules[name](pred - mean_p, target - mean_t) - elif name == "lpips": - p = apply_gamma(pred).clamp(0, 1) * 2 - 1 - t = apply_gamma(target).clamp(0, 1) * 2 - 1 - val = self.loss_modules[name](p, t).mean() - else: - val = self.loss_modules[name](pred, target) - loss += weight * val - - return loss diff --git a/src/Restorer/Cond_NAF.py b/src/Restorer/Cond_NAF.py index 9fa29dc..f503564 100644 --- a/src/Restorer/Cond_NAF.py +++ b/src/Restorer/Cond_NAF.py @@ -2,47 +2,24 @@ import torch import torch.nn as nn -class LayerNormFunction(torch.autograd.Function): - @staticmethod - def forward(ctx, x, weight, bias, eps): - ctx.eps = eps - N, C, H, W = x.size() - mu = x.mean(1, keepdim=True) - var = (x - mu).pow(2).mean(1, keepdim=True) - y = (x - mu) / (var + eps).sqrt() - ctx.save_for_backward(y, var, weight) - y = weight.view(1, C, 1, 1) * y + bias.view(1, C, 1, 1) - return y - - @staticmethod - def backward(ctx, grad_output): - eps = ctx.eps - - N, C, H, W = grad_output.size() - y, var, weight = ctx.saved_variables - g = grad_output * 
weight.view(1, C, 1, 1) - mean_g = g.mean(dim=1, keepdim=True) - - mean_gy = (g * y).mean(dim=1, keepdim=True) - gx = 1.0 / torch.sqrt(var + eps) * (g - y * mean_gy - mean_g) - return ( - gx, - (grad_output * y).sum(dim=3).sum(dim=2).sum(dim=0), - grad_output.sum(dim=3).sum(dim=2).sum(dim=0), - None, - ) - - class LayerNorm2d(nn.Module): def __init__(self, channels, eps=1e-6): super(LayerNorm2d, self).__init__() self.register_parameter("weight", nn.Parameter(torch.ones(channels))) self.register_parameter("bias", nn.Parameter(torch.zeros(channels))) self.eps = eps - + def forward(self, x): - return LayerNormFunction.apply(x, self.weight, self.bias, self.eps) - + mu = x.mean(1, keepdim=True) + var = (x - mu).pow(2).mean(1, keepdim=True) + + y = (x - mu) / torch.sqrt(var + self.eps) + + weight_view = self.weight.view(1, self.weight.size(0), 1, 1) + bias_view = self.bias.view(1, self.bias.size(0), 1, 1) + + y = weight_view * y + bias_view + return y class SimpleGate(nn.Module): def forward(self, x): diff --git a/src/Restorer/Cond_NAF_ps.py b/src/Restorer/Cond_NAF_ps.py index f7932ec..6c79a18 100644 --- a/src/Restorer/Cond_NAF_ps.py +++ b/src/Restorer/Cond_NAF_ps.py @@ -5,32 +5,20 @@ class LayerNorm2d(nn.Module): def __init__(self, channels, eps=1e-6): super(LayerNorm2d, self).__init__() - # 1. Keep the weight and bias as standard nn.Parameters self.register_parameter("weight", nn.Parameter(torch.ones(channels))) self.register_parameter("bias", nn.Parameter(torch.zeros(channels))) self.eps = eps - # 2. REMOVE the self.weight_view and self.bias_view initializations from here - # They will be created dynamically in forward. - def forward(self, x): - # N, C, H, W = x.size() # While useful for clarity, not strictly needed for the operations - - # 1. Calculate Mean (mu) and Variance (var) across the Channel dimension (1) - # Note: We are sticking to your original normalization over C (dim=1) mu = x.mean(1, keepdim=True) var = (x - mu).pow(2).mean(1, keepdim=True) - # 2. 
Normalize the input y = (x - mu) / torch.sqrt(var + self.eps) - # 3. Create the views INSIDE the forward pass, so they are part of the traced graph weight_view = self.weight.view(1, self.weight.size(0), 1, 1) bias_view = self.bias.view(1, self.bias.size(0), 1, 1) - # 4. Apply the learnable scale (weight) and shift (bias) y = weight_view * y + bias_view - return y class SimpleGate(nn.Module): From 2e3c9808da38897c1c9cd6dcd645d1cf9667b7be Mon Sep 17 00:00:00 2001 From: Ryan Mueller Date: Thu, 16 Oct 2025 20:09:05 -0400 Subject: [PATCH 08/56] Moved combined perceptual losses --- src/training/losses/CombinedPerceptualLoss.py | 203 ++++++++++++++++++ 1 file changed, 203 insertions(+) create mode 100644 src/training/losses/CombinedPerceptualLoss.py diff --git a/src/training/losses/CombinedPerceptualLoss.py b/src/training/losses/CombinedPerceptualLoss.py new file mode 100644 index 0000000..75b3b07 --- /dev/null +++ b/src/training/losses/CombinedPerceptualLoss.py @@ -0,0 +1,203 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision +from torchvision import models +from pytorch_msssim import ms_ssim +from pytorch_msssim.ssim import _ssim, _fspecial_gauss_1d +import lpips +import torch_dct as dct_2d +from kornia.color import rgb_to_ycbcr + + +# Losses + + +class FrequencyLoss(nn.Module): + def forward(self, output, target): + return F.l1_loss(dct_2d.dct_2d(output), dct_2d.dct_2d(target)) + + +class ChromaLoss(nn.Module): + def __init__(self, loss_func=F.l1_loss): + super().__init__() + self.loss_func = loss_func + + def forward(self, output, target): + return self.loss_func(rgb_to_ycbcr(output)[:, 1:], rgb_to_ycbcr(target)[:, 1:]) + + +class LumaLoss(nn.Module): + def __init__(self, loss_func=F.l1_loss): + super().__init__() + self.loss_func = loss_func + + def forward(self, output, target): + return self.loss_func(rgb_to_ycbcr(output)[:, :1], rgb_to_ycbcr(target)[:, :1]) + + +class VGGPerceptualLoss(nn.Module): + def __init__( + self, 
layers=[0, 5, 10, 19, 28], weights=[1.0] * 5, apply_gamma_curve=False + ): + super().__init__() + vgg = models.vgg16(weights=True).features[: max(layers) + 1] + for param in vgg.parameters(): + param.requires_grad = False + self.vgg = vgg.eval() + self.layers = layers + self.weights = weights + self.criterion = nn.L1Loss() + self.apply_gamma_curve = apply_gamma_curve + + def forward(self, pred, target): + if self.apply_gamma_curve: + pred = apply_gamma(pred) + target = apply_gamma(target) + loss = 0.0 + x = pred + y = target + wdix = 0 + for idx, layer in enumerate(self.vgg): + x = layer(x) + y = layer(y) + if idx in self.layers: + loss += self.weights[wdix] * self.criterion(x, y) + wdix += 1 + return loss + + +def gram_matrix(features): + """Compute the Gram matrix from feature maps.""" + B, C, H, W = features.size() + features = features.view(B, C, H * W) # Flatten spatial dimensions + gram = torch.bmm(features, features.transpose(1, 2)) # Batch matrix multiplication + gram = gram / (C * H * W) # Normalize + return gram + + +class StyleLoss(VGGPerceptualLoss): + def forward(self, pred, target): + if self.apply_gamma_curve: + pred = apply_gamma(pred) + target = apply_gamma(target) + loss = 0.0 + x = pred + y = target + wdix = 0 + for idx, layer in enumerate(self.vgg): + x = layer(x) + y = layer(y) + if idx in self.layers: + x_gram = gram_matrix(x) + y_gram = gram_matrix(y) + loss += self.weights[wdix] * self.criterion(x_gram, y_gram) + wdix += 1 + return loss + + +def transform_vgg(input): + mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1) + std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1) + return (input - mean) / std + + +def apply_gamma(_tensor): + tensor = _tensor.clone() + img_mask = tensor > 0.0031308 + tensor[img_mask] = 1.055 * torch.pow(tensor[img_mask], 1.0 / 2.4) - 0.055 + tensor[~img_mask] *= 12.92 + return tensor + + +def sharpness_loss(pred, target): + pred_blur = torchvision.transforms.functional.gaussian_blur(pred, 11, 
sigma=1.0) + target_blur = torchvision.transforms.functional.gaussian_blur(target, 11, sigma=1.0) + return F.l1_loss(pred - pred_blur, target - target_blur) + + +def conservative_l1(pred, inp, gt, sout=1, sin=1 / 10): + same_side = (pred - gt) * (inp - gt) > 0 + between = (pred - gt) * (pred - inp) < 0 + diff = (pred - gt).abs() + diff[between] *= sin + diff[~same_side] *= sout + diff[same_side] *= sout + diff[same_side * ~between] += (gt - inp).abs()[same_side * ~between] * (sin - 1) + return diff.mean() + + +def structural_loss(pred, gt): + win_size, win_sigma = 11, 1.5 + win = _fspecial_gauss_1d(win_size, win_sigma) + win = win.repeat([pred.shape[1]] + [1] * (len(pred.shape) - 1)) + _, cs = _ssim(pred, gt, data_range=1.0, win=win) + return 1 - cs.mean() + + +def gradient_loss(pred, gt): + pred_dx = pred[:, :, :, 1:] - pred[:, :, :, :-1] + gt_dx = gt[:, :, :, 1:] - gt[:, :, :, :-1] + pred_dy = pred[:, :, 1:, :] - pred[:, :, :-1, :] + gt_dy = gt[:, :, 1:, :] - gt[:, :, :-1, :] + return torch.mean(torch.abs(pred_dx - gt_dx)) + torch.mean( + torch.abs(pred_dy - gt_dy) + ) + + +# ==== Combined Loss Class ==== + + +class CombinedPerceptualLoss(nn.Module): + def __init__(self, **lambdas): + super().__init__() + self.loss_weights = lambdas + + # Register loss components + self.loss_modules = { + "l1": nn.L1Loss(), + "mse": nn.MSELoss(), + "ssim": lambda p, t: 1 - ms_ssim(p, t, data_range=1.0, size_average=True), + "ssim_coarse": lambda p, t: 1 + - ms_ssim( + p, + t, + data_range=1.0, + size_average=True, + weights=[0.1, 0.2, 0.448028, 0.352898, 0.199074], + ), + "ssim_fine": lambda p, t: 1 + - ms_ssim(p, t, data_range=1.0, size_average=True, weights=[1, 0, 0, 0, 0]), + "vgg": VGGPerceptualLoss(apply_gamma_curve=True), + "style": StyleLoss( + layers=[0, 5, 10, 17], + weights=[1.0, 1.0, 0.5, 0.25], + apply_gamma_curve=True, + ), + "frequency": FrequencyLoss(), + "chroma": ChromaLoss(), + "luma": LumaLoss(), + "luma_mse": LumaLoss(loss_func=F.mse_loss), + "luma_struct": 
LumaLoss(loss_func=structural_loss), + "lpips": lpips.LPIPS(net="vgg"), + "sharpness_loss": sharpness_loss, + } + + def forward(self, pred, target): + loss = 0.0 + for name, weight in self.loss_weights.items(): + if weight <= 0: + continue + if name == "invariant_l1": + mean_t = target.mean(dim=(1, 2, 3), keepdim=True) + mean_p = pred.mean(dim=(1, 2, 3), keepdim=True) + val = self.loss_modules[name](pred - mean_p, target - mean_t) + elif name == "lpips": + p = apply_gamma(pred).clamp(0, 1) * 2 - 1 + t = apply_gamma(target).clamp(0, 1) * 2 - 1 + val = self.loss_modules[name](p, t).mean() + else: + val = self.loss_modules[name](pred, target) + loss += weight * val + + return loss From 0e2f0b7ee21128416e7e8b773c28118f82c85e4f Mon Sep 17 00:00:00 2001 From: Ryan Mueller Date: Thu, 16 Oct 2025 20:09:15 -0400 Subject: [PATCH 09/56] update to gitignore --- .gitignore | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 7b004e5..fd9e85c 100644 --- a/.gitignore +++ b/.gitignore @@ -191,4 +191,8 @@ cython_debug/ # exclude from AI features like autocomplete and code analysis. 
Recommended for sensitive data # refer to https://docs.cursor.com/context/ignore-files .cursorignore -.cursorindexingignore \ No newline at end of file +.cursorindexingignore + + +# MLrn +mlruns/ \ No newline at end of file From 9ed9ca1dc05dc1a6d0af230159ccad27d59e9bdc Mon Sep 17 00:00:00 2001 From: Ryan Mueller Date: Thu, 16 Oct 2025 20:09:45 -0400 Subject: [PATCH 10/56] Update to small dataset to control colorspace --- 0_produce_small_dataset.ipynb | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/0_produce_small_dataset.ipynb b/0_produce_small_dataset.ipynb index 1c419a3..d1b07e9 100644 --- a/0_produce_small_dataset.ipynb +++ b/0_produce_small_dataset.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "id": "23c4096e", "metadata": {}, "outputs": [], @@ -29,7 +29,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "id": "5b6231f2", "metadata": {}, "outputs": [], @@ -57,13 +57,14 @@ "outpath = Path(run_config['jpeg_output_subdir'])\n", "alignment_csv = outpath / run_config['align_csv']\n", "outpath_cropped = run_config['cropped_jpeg_subdir']\n", + "colorspace = run_config['colorspace']\n", "\n", "file_list = os.listdir(raw_path)" ] }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "id": "783efea9", "metadata": {}, "outputs": [], @@ -125,7 +126,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "id": "e7f13812", "metadata": {}, "outputs": [], @@ -196,7 +197,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "id": "7d9504cd", "metadata": {}, "outputs": [], @@ -215,7 +216,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "id": "419b9857", "metadata": {}, "outputs": [], @@ -226,8 +227,8 @@ " and ECC (fine).\n", " \"\"\"\n", " # 1. 
Load raw files\n", - " noisy_handler = RawHandler(f'{path}/{noisy_fname}')\n", - " gt_handler = RawHandler(f'{path}/{gt_fname}')\n", + " noisy_handler = RawHandler(f'{path}/{noisy_fname}', colorspace=colorspace)\n", + " gt_handler = RawHandler(f'{path}/{gt_fname}', colorspace=colorspace)\n", "\n", " noisy_bayer = noisy_handler.apply_colorspace_transform(colorspace='lin_rec2020', clip=True).astype(np.float32)\n", " gt_bayer = gt_handler.apply_colorspace_transform(colorspace='lin_rec2020', clip=True).astype(np.float32)\n", From 1855fa0e13bcb3cbdc3c3794134e4f78b0b3af92 Mon Sep 17 00:00:00 2001 From: Ryan Mueller Date: Thu, 16 Oct 2025 20:09:54 -0400 Subject: [PATCH 11/56] Added yaml config --- config.yaml | 41 ++++++++++++++++++++++++++++++++++++++++- 1 file changed, 40 insertions(+), 1 deletion(-) diff --git a/config.yaml b/config.yaml index 870b70a..ed7d025 100644 --- a/config.yaml +++ b/config.yaml @@ -1,6 +1,45 @@ # config.yaml +# --- Paths --- base_data_dir: /Volumes/EasyStore/RAWNIND/ jpeg_output_subdir: /Users/ryanmueller/Pictures/JPEGs_high_quality cropped_jpeg_subdir: /Users/ryanmueller/Pictures/Cropped_JPEG_high_quality align_csv: align_data_high_quality.csv -secondary_align_csv: secondary_align.csv \ No newline at end of file +secondary_align_csv: secondary_align.csv +script_path: 'traces_and_weights' + +# --- Training Params --- +colorspace: lin_rec2020 +device: mps +batch_size: 2 +crop_size: 256 +lr_base: 2.5e-5 +clipping: 1e-2 +num_epochs_pretraining: 75 +num_epochs_finetuning: 20 +val_split: 0.2 +random_seed: 42 + +# --- Experiment Settings --- +experiment: NAF_test +mlflow_experiment: NAFNet_variations + +# --- Run Configuration ---: +run_name: NAF_deep +run_path: NAF_deep +model_params: + chans: [32, 64, 128, 256, 256, 256] + enc_blk_nums: [2, 2, 2, 3, 4] + middle_blk_num: 12 + dec_blk_nums: [2, 2, 2, 2, 2] + cond_input: 1 + in_channels: 4 + out_channels: 3 + rggb: True + +# --- Loss Configureation ---: +alpha: 0.2 +beta: 5.0 +l1_weight: 0.16 
+ssim_weight: 0.84 +tv_weight: 0.0 +vgg_loss_weight: 0 \ No newline at end of file From ce39746d26c3f63cd00cd755c253d6abff1849b5 Mon Sep 17 00:00:00 2001 From: Ryan Mueller Date: Thu, 16 Oct 2025 20:10:17 -0400 Subject: [PATCH 12/56] Flag to control number of batches for quick testing --- 1_pretrain_model.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/1_pretrain_model.ipynb b/1_pretrain_model.ipynb index a737238..e092d15 100644 --- a/1_pretrain_model.ipynb +++ b/1_pretrain_model.ipynb @@ -152,7 +152,7 @@ " mlflow.log_params(params)\n", " for epoch in range(num_epochs):\n", " train_one_epoch(epoch, model, optimizer, train_loader, device, loss_fn, clipping, \n", - " log_interval = 10, sleep=0.0, rggb=rggb, max_batches=2)\n", + " log_interval = 10, sleep=0.0, rggb=rggb, max_batches=0)\n", " \n", " mlflow.pytorch.log_model(\n", " pytorch_model=model,\n", From 947ccfee9355d9566efc6edabff986fca0aad279 Mon Sep 17 00:00:00 2001 From: Ryan Mueller Date: Thu, 16 Oct 2025 20:10:34 -0400 Subject: [PATCH 13/56] Notebook to make script file --- 3_make_script.ipynb | 124 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 124 insertions(+) create mode 100644 3_make_script.ipynb diff --git a/3_make_script.ipynb b/3_make_script.ipynb new file mode 100644 index 0000000..624036e --- /dev/null +++ b/3_make_script.ipynb @@ -0,0 +1,124 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "685f1f4b", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "from pathlib import Path\n", + "import mlflow" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a4c87e77", + "metadata": {}, + "outputs": [], + "source": [ + "from src.training.load_config import load_config" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "72145448", + "metadata": {}, + "outputs": [], + "source": [ + "run_config" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5f499abe", + 
"metadata": {}, + "outputs": [], + "source": [ + "run_config = load_config()\n", + "output = Path(run_config['script_path']) / f\"{run_config['run_path']}.pt\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f5eb1679", + "metadata": {}, + "outputs": [], + "source": [ + "RUN_ID = \"66f11f4639f24e9ea75b4c953147be15\" \n", + "ARTIFACT_PATH = run_config['run_path']\n", + "\n", + "model_uri = f\"runs:/{RUN_ID}/{ARTIFACT_PATH}\"\n", + "\n", + "try:\n", + " model = mlflow.pytorch.load_model(model_uri)\n", + " model.eval()\n", + " model.to('cpu')\n", + " print(f\"Model successfully loaded from MLflow URI: {model_uri}\")\n", + " \n", + "\n", + "except Exception as e:\n", + " print(f\"Error loading model from MLflow: {e}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "913f088c", + "metadata": {}, + "outputs": [], + "source": [ + "run_config['crop_size']" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d1e4397b", + "metadata": {}, + "outputs": [], + "source": [ + "in_channels = run_config['model_params']['in_channels']\n", + "out_channels = run_config['model_params']['out_channels']\n", + "cond_input = run_config['model_params']['cond_input']\n", + "crop_size = run_config['crop_size']\n", + "\n", + "if run_config['model_params']['rggb']:\n", + " input = torch.rand(1, in_channels, crop_size//2, crop_size//2)\n", + "else:\n", + " input = torch.rand(1, in_channels, crop_size, crop_size)\n", + "\n", + "cond = torch.rand(1, cond_input)\n", + "residual = torch.rand(1, out_channels, crop_size, crop_size)\n", + "traced_script_module = torch.jit.trace(model, (input, cond, residual))\n", + "traced_script_module.save(output)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "OnSight", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + 
"nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.11" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} From 52ae6721504caf5b7d29e892bdc7052c08df72fc Mon Sep 17 00:00:00 2001 From: Ryan Mueller Date: Tue, 21 Oct 2025 11:22:08 -0400 Subject: [PATCH 14/56] Merging phase alignment output to be consistent with original alignment csv --- src/training/align_images.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/training/align_images.py b/src/training/align_images.py index b63f351..f24ad0e 100644 --- a/src/training/align_images.py +++ b/src/training/align_images.py @@ -15,6 +15,7 @@ import numpy as np import torch import cv2 +from pathlib import Path @@ -136,9 +137,6 @@ def apply_alignment(img, warp_params, interpolation=cv2.INTER_LINEAR): ) return aligned - - - class AlignImages(Dataset): def __init__(self, path, csv, crop_size=180, buffer=10, validation=False): super().__init__() @@ -176,6 +174,6 @@ def __getitem__(self, idx): aligned, matrix, metrics = align_clean_to_noisy(gt_image, demosaiced_noisy, refine=False) metrics['iso'] = row.iso metrics['std'] = (demosaiced_noisy.astype(int) - aligned.astype(int)).std() - metrics['noisy_image'] = bayer_path - metrics['gt_image'] = gt_path + metrics['bayer_path'] = Path(bayer_path).name + metrics['gt_path'] = Path(gt_path).name return gt_image, demosaiced_noisy, aligned, metrics \ No newline at end of file From e076f88cb6f6c553f80313628248a32aea5c9d35 Mon Sep 17 00:00:00 2001 From: Ryan Mueller Date: Tue, 21 Oct 2025 11:22:29 -0400 Subject: [PATCH 15/56] Output both train loss and l1 tracking loss in validation loop --- src/training/train_loop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/training/train_loop.py b/src/training/train_loop.py index 3b3296b..7c75292 100644 --- a/src/training/train_loop.py +++ b/src/training/train_loop.py @@ -130,4 +130,4 @@ def visualize(idxs, _model, dataset, _device, _loss_func, rggb=False): f"Time: 
{perf_counter()-start:.1f}s " f"Images: {n_images}") - return total_loss / max(1, n_images), perf_counter()-start \ No newline at end of file + return total_loss / max(1, n_images), total_final_image_loss / max(1, n_images), perf_counter()-start \ No newline at end of file From 92cf9ba8d4452e852e350cc8a9283d89881b5579 Mon Sep 17 00:00:00 2001 From: Ryan Mueller Date: Tue, 21 Oct 2025 11:23:56 -0400 Subject: [PATCH 16/56] Make small raw dataset more robust to input path and added option to rerun alignment --- src/training/SmallRawDataset.py | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/src/training/SmallRawDataset.py b/src/training/SmallRawDataset.py index 6bd8919..50fd962 100644 --- a/src/training/SmallRawDataset.py +++ b/src/training/SmallRawDataset.py @@ -12,12 +12,11 @@ from src.training.utils import inverse_gamma_tone_curve, cfa_to_sparse import numpy as np import torch -import cv2 -from src.training.align_images import apply_alignment - +from src.training.align_images import apply_alignment, align_clean_to_noisy +from pathlib import Path class SmallRawDataset(Dataset): - def __init__(self, path, csv, crop_size=180, buffer=10, validation=False): + def __init__(self, path, csv, crop_size=180, buffer=10, validation=False, run_align=False): super().__init__() self.df = pd.read_csv(csv) self.path = path @@ -25,24 +24,30 @@ def __init__(self, path, csv, crop_size=180, buffer=10, validation=False): self.buffer = buffer self.coordinate_iso = 6400 self.validation=validation - + self.run_align = run_align def __len__(self): return len(self.df) def __getitem__(self, idx): row = self.df.iloc[idx] # Load images - with imageio.imopen(self.path / f"{row.noisy_image}_bayer.jpg", "r") as image_resource: + with imageio.imopen(self.path / Path(f"{row.bayer_path}").name, "r") as image_resource: bayer_data = image_resource.read() - with imageio.imopen(self.path / f"{row.gt_image}.jpg", "r") as image_resource: + with 
imageio.imopen(self.path / Path(f"{row.gt_path}").name, "r") as image_resource: gt_image = image_resource.read() gt_image = gt_image/255 bayer_data = bayer_data/255 - - aligned = apply_alignment(gt_image, row.to_dict()) demosaiced_noisy = demosaicing_CFA_Bayer_Malvar2004(bayer_data) + if self.run_align: + gt_image = (gt_image * 255).astype(np.uint8) + demosaiced_noisy = (demosaiced_noisy * 255).astype(np.uint8) + aligned, _, _ = align_clean_to_noisy(gt_image, demosaiced_noisy, refine=False, verbose=False) + aligned = aligned / 255 + else: + aligned = apply_alignment(gt_image, row.to_dict()) + h, w, _ = gt_image.shape #Crop images From 1a1306694131ffbaed2bf97083b542427a36841f Mon Sep 17 00:00:00 2001 From: Ryan Mueller Date: Tue, 21 Oct 2025 11:25:17 -0400 Subject: [PATCH 17/56] Updated config to reflect new alignment output --- config.yaml | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/config.yaml b/config.yaml index ed7d025..a3aaf27 100644 --- a/config.yaml +++ b/config.yaml @@ -1,11 +1,12 @@ # config.yaml # --- Paths --- base_data_dir: /Volumes/EasyStore/RAWNIND/ -jpeg_output_subdir: /Users/ryanmueller/Pictures/JPEGs_high_quality -cropped_jpeg_subdir: /Users/ryanmueller/Pictures/Cropped_JPEG_high_quality -align_csv: align_data_high_quality.csv -secondary_align_csv: secondary_align.csv -script_path: 'traces_and_weights' +jpeg_output_subdir: /Volumes/EasyStore/RAWNIND/JPEGs/Cropped_JPEG +cropped_jpeg_subdir: /Volumes/EasyStore/RAWNIND/JPEGs/Cropped_JPEG +align_csv: align_data.csv +secondary_align_csv: align_phase_v2.csv +script_path: traces_and_weights +mlflow_path: /Volumes/EasyStore/models/mlfow # --- Training Params --- colorspace: lin_rec2020 @@ -14,7 +15,7 @@ batch_size: 2 crop_size: 256 lr_base: 2.5e-5 clipping: 1e-2 -num_epochs_pretraining: 75 +num_epochs_pretraining: 50 num_epochs_finetuning: 20 val_split: 0.2 random_seed: 42 @@ -24,8 +25,8 @@ experiment: NAF_test mlflow_experiment: NAFNet_variations # --- Run 
Configuration ---: -run_name: NAF_deep -run_path: NAF_deep +run_name: NAF_deep_test_align +run_path: NAF_deep_test_align model_params: chans: [32, 64, 128, 256, 256, 256] enc_blk_nums: [2, 2, 2, 3, 4] From 43a34c1e8a3a0d74ae48332ad4c02c8a6676b224 Mon Sep 17 00:00:00 2001 From: Ryan Mueller Date: Tue, 21 Oct 2025 23:23:07 -0400 Subject: [PATCH 18/56] Experimental uncompressed dataset --- 0_produce_small_dataset_raw.ipynb | 230 +++++++++++++++++++++++++++ config.yaml | 6 +- src/training/SmallRawDatasetNumpy.py | 99 ++++++++++++ 3 files changed, 333 insertions(+), 2 deletions(-) create mode 100644 0_produce_small_dataset_raw.ipynb create mode 100644 src/training/SmallRawDatasetNumpy.py diff --git a/0_produce_small_dataset_raw.ipynb b/0_produce_small_dataset_raw.ipynb new file mode 100644 index 0000000..3ae0532 --- /dev/null +++ b/0_produce_small_dataset_raw.ipynb @@ -0,0 +1,230 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "23c4096e", + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "import numpy as np\n", + "import imageio\n", + "import cv2\n", + "import matplotlib.pyplot as plt\n", + "import os\n", + "from pathlib import Path\n", + "import re\n", + "from PIL import Image\n", + "import re\n", + "from collections import defaultdict\n", + "\n", + "from colour_demosaicing import (\n", + " ROOT_RESOURCES_EXAMPLES,\n", + " demosaicing_CFA_Bayer_bilinear,\n", + " demosaicing_CFA_Bayer_Malvar2004,\n", + " demosaicing_CFA_Bayer_Menon2007,\n", + " mosaicing_CFA_Bayer)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5b6231f2", + "metadata": {}, + "outputs": [], + "source": [ + "from RawHandler.RawHandler import RawHandler\n", + "from RawHandler.utils import linear_to_srgb\n", + "from src.training.load_config import load_config\n", + "\n", + "def apply_gamma(x, gamma=2.2):\n", + " return x ** (1 / gamma)\n", + "\n", + "def reverse_gamma(x, gamma=2.2):\n", + " return x ** gamma" + ] + }, + { 
+ "cell_type": "code", + "execution_count": null, + "id": "d2e21592", + "metadata": {}, + "outputs": [], + "source": [ + "run_config = load_config()\n", + "raw_path = Path(run_config['base_data_dir'])\n", + "outpath = Path(run_config['cropped_raw_subdir'])\n", + "alignment_csv = outpath / run_config['align_csv']\n", + "colorspace = run_config['colorspace']\n", + "crop_size = run_config['cropped_raw_size']\n", + "file_list = os.listdir(raw_path)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "783efea9", + "metadata": {}, + "outputs": [], + "source": [ + "def pair_images_by_scene(file_list, min_iso=100):\n", + " \"\"\"\n", + " Given a list of RAW image file paths:\n", + " 1. Extract ISO from filenames\n", + " 2. Remove files with ISO < min_iso\n", + " 3. Group by scene name\n", + " 4. Pair each image with the lowest-ISO version of the scene\n", + "\n", + " Args:\n", + " file_list (list of str): Paths to RAW files\n", + " min_iso (int): Minimum ISO to keep (default=100)\n", + "\n", + " Returns:\n", + " dict: {scene_name: [(img_path, gt_path), ...]}\n", + " \"\"\"\n", + " iso_pattern = re.compile(r\"_ISO(\\d+)_\")\n", + " scene_pairs = {}\n", + "\n", + " # Step 1: Extract iso and scene\n", + " images = []\n", + " for path in file_list:\n", + " filename = os.path.basename(path)\n", + " match = iso_pattern.search(filename)\n", + " if not match:\n", + " continue # skip if no ISO\n", + " iso = int(match.group(1))\n", + " if iso < min_iso:\n", + " continue # filter out low ISOs\n", + "\n", + " # Extract scene name:\n", + " if \"_GT_\" in filename:\n", + " scene = filename.split(\"_GT_\")[0]\n", + " else:\n", + " # Scene = part before \"_ISO\"\n", + " scene = filename.split(\"_ISO\")[0]\n", + " if 'X-Trans' in filename:\n", + " continue\n", + "\n", + " images.append((scene, iso, path))\n", + "\n", + " # Step 2: Group by scene\n", + " grouped = defaultdict(list)\n", + " for scene, iso, path in images:\n", + " grouped[scene].append((iso, path))\n", + 
"\n", + " # Step 3: For each scene, pick lowest ISO as GT\n", + " for scene, iso_paths in grouped.items():\n", + " iso_paths.sort(key=lambda x: x[0]) # sort by ISO ascending\n", + " gt_iso, gt_path = iso_paths[0] # lowest ISO ≥ min_iso\n", + " pairs = [(path, gt_path) for iso, path in iso_paths if path != gt_path]\n", + " scene_pairs[scene] = pairs\n", + "\n", + " return scene_pairs\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9af686c6", + "metadata": {}, + "outputs": [], + "source": [ + "pair_file_list = pair_images_by_scene(file_list)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1213a101", + "metadata": {}, + "outputs": [], + "source": [ + "def get_file(impath, crop_size=crop_size):\n", + " rh = RawHandler(impath)\n", + " \n", + " width, height = rh.raw.shape\n", + "\n", + " # Check if the image is large enough to be cropped.\n", + " if width < crop_size or height < crop_size:\n", + " im = rh.apply_colorspace_transform( colorspace=colorspace)\n", + " else:\n", + " # Calculate the coordinates for the center crop.\n", + " left = (width - crop_size) // 2\n", + " top = (height - crop_size) // 2\n", + "\n", + " # Ensure the top-left corner is on an even pixel coordinate for Bayer alignment.\n", + " if left % 2 != 0:\n", + " left -= 1\n", + " if top % 2 != 0:\n", + " top -= 1\n", + " \n", + " # Calculate the bottom-right corner based on the adjusted top-left corner.\n", + " # Since crop_size is even, right and bottom will also be even.\n", + " right = left + crop_size\n", + " bottom = top + crop_size\n", + "\n", + " im = rh.apply_colorspace_transform(dims=(left, right, top, bottom), colorspace=colorspace)\n", + " im_scaled = im * 65535.0\n", + " im_clipped = np.clip(im_scaled, 0.0, 65535.0)\n", + "\n", + " im_uint16 = im_clipped.astype(np.uint16)\n", + " return im_uint16" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "010774f5", + "metadata": {}, + "outputs": [], + "source": [ + "from 
tqdm import tqdm" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5c0c0816", + "metadata": {}, + "outputs": [], + "source": [ + "for key in tqdm(pair_file_list.keys()):\n", + " image_pairs = pair_file_list[key]\n", + " for noisy, gt in image_pairs:\n", + " try:\n", + " noisy_bayer = get_file(f'{raw_path}/{noisy}')\n", + " noisy_path = outpath / (noisy + \".u16.raw\")\n", + " noisy_bayer.tofile(noisy_path)\n", + "\n", + " gt_path = outpath / (gt + \".u16.raw\")\n", + " if not os.path.exists(gt_path):\n", + " gt_bayer = get_file(f'{raw_path}/{gt}')\n", + " gt_bayer.tofile(gt_path)\n", + " except:\n", + " print(f\"Skipping {raw_path}/{noisy}, {raw_path}/{gt}\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "OnSight", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.11" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/config.yaml b/config.yaml index a3aaf27..8df6c47 100644 --- a/config.yaml +++ b/config.yaml @@ -3,6 +3,8 @@ base_data_dir: /Volumes/EasyStore/RAWNIND/ jpeg_output_subdir: /Volumes/EasyStore/RAWNIND/JPEGs/Cropped_JPEG cropped_jpeg_subdir: /Volumes/EasyStore/RAWNIND/JPEGs/Cropped_JPEG +cropped_raw_subdir: /Volumes/EasyStore/RAWNIND/Cropped_Raw +cropped_raw_size: 2000 align_csv: align_data.csv secondary_align_csv: align_phase_v2.csv script_path: traces_and_weights @@ -25,8 +27,8 @@ experiment: NAF_test mlflow_experiment: NAFNet_variations # --- Run Configuration ---: -run_name: NAF_deep_test_align -run_path: NAF_deep_test_align +run_name: NAF_deep_test_align_raw +run_path: NAF_deep_test_align_raw model_params: chans: [32, 64, 128, 256, 256, 256] enc_blk_nums: [2, 2, 2, 3, 4] diff --git a/src/training/SmallRawDatasetNumpy.py 
b/src/training/SmallRawDatasetNumpy.py new file mode 100644 index 0000000..806d67c --- /dev/null +++ b/src/training/SmallRawDatasetNumpy.py @@ -0,0 +1,99 @@ +import pandas as pd +import os +from torch.utils.data import Dataset +import imageio +from colour_demosaicing import ( + ROOT_RESOURCES_EXAMPLES, + demosaicing_CFA_Bayer_bilinear, + demosaicing_CFA_Bayer_Malvar2004, + demosaicing_CFA_Bayer_Menon2007, + mosaicing_CFA_Bayer) + +from src.training.utils import inverse_gamma_tone_curve, cfa_to_sparse +import numpy as np +import torch +from src.training.align_images import apply_alignment, align_clean_to_noisy +from pathlib import Path + +class SmallRawDatasetNumpy(Dataset): + def __init__(self, path, csv, crop_size=180, buffer=10, validation=False, run_align=False, dimensions=2000): + super().__init__() + self.df = pd.read_csv(csv) + self.path = path + self.crop_size = crop_size + self.buffer = buffer + self.coordinate_iso = 6400 + self.validation=validation + self.run_align = run_align + self.dtype = np.uint16 + self.dimensions = dimensions + def __len__(self): + return len(self.df) + + def __getitem__(self, idx): + row = self.df.iloc[idx] + # Load images + + name = Path(f"{row.bayer_path}").name + name = name.replace('_bayer.jpg', '.u16.raw') + bayer_data = np.fromfile(self.path / name, dtype=self.dtype) + bayer_data = bayer_data.reshape((self.dimensions, self.dimensions)) + + name = Path(f"{row.gt_path}").name + name = name.replace('jpg', 'u16.raw') + gt_image = np.fromfile(self.path / name, dtype=self.dtype) + gt_image = gt_image.reshape((self.dimensions, self.dimensions)) + + + + gt_image = gt_image/65535 + bayer_data = bayer_data/65535 + gt_image = demosaicing_CFA_Bayer_Malvar2004(gt_image) + demosaiced_noisy = demosaicing_CFA_Bayer_Malvar2004(bayer_data) + + if self.run_align: + gt_image = (gt_image * 255).astype(np.uint8) + demosaiced_noisy = (demosaiced_noisy * 255).astype(np.uint8) + aligned, _, _ = align_clean_to_noisy(gt_image, demosaiced_noisy, 
refine=False, verbose=False) + aligned = aligned / 255 + else: + aligned = apply_alignment(gt_image, row.to_dict()) + + h, w, _ = gt_image.shape + + #Crop images + if not self.validation: + top = np.random.randint(0 + self.buffer, h - self.crop_size - self.buffer) + left = np.random.randint(0 + self.buffer, w - self.crop_size - self.buffer) + else: + top = (h - self.crop_size) // 2 + left = (w - self.crop_size) // 2 + + if top % 2 != 0: top = top - 1 + if left % 2 != 0: left = left - 1 + bottom = top + self.crop_size + right = left + self.crop_size + aligned = aligned[top:bottom, left:right] + gt_image = gt_image[top:bottom, left:right] + bayer_data = bayer_data[top:bottom, left:right] + h, w, _ = gt_image.shape + + demosaiced_noisy = demosaicing_CFA_Bayer_Malvar2004(bayer_data) + + aligned = aligned * demosaiced_noisy.mean() / aligned.mean() + gt_image = gt_image * demosaiced_noisy.mean() / gt_image.mean() + + sparse, _ = cfa_to_sparse(bayer_data) + rggb = bayer_data.reshape(h // 2, 2, w // 2, 2, 1).transpose(1, 3, 4, 0, 2).reshape(4, h // 2, w // 2) + + # Convert to tensors + output = { + "bayer": torch.tensor(bayer_data).to(float).clip(0,1), + "gt": torch.tensor(gt_image).to(float).permute(2, 0, 1).clip(0,1), + "aligned": torch.tensor(aligned).to(float).permute(2, 0, 1).clip(0,1), + "sparse": torch.tensor(sparse).to(float).clip(0,1), + "noisy": torch.tensor(demosaiced_noisy).to(float).permute(2, 0, 1).clip(0,1), + "rggb": torch.tensor(rggb).to(float).clip(0,1), + "conditioning": torch.tensor([row.iso/self.coordinate_iso]).to(float), + } + return output \ No newline at end of file From 922d07ca5aa29da3b8773a60d8b691c910f85c30 Mon Sep 17 00:00:00 2001 From: Ryan Mueller Date: Thu, 23 Oct 2025 10:17:44 -0400 Subject: [PATCH 19/56] Wider before deeper config --- config.yaml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/config.yaml b/config.yaml index 8df6c47..61acb29 100644 --- a/config.yaml +++ b/config.yaml @@ -27,10 +27,10 @@ experiment: 
NAF_test mlflow_experiment: NAFNet_variations # --- Run Configuration ---: -run_name: NAF_deep_test_align_raw -run_path: NAF_deep_test_align_raw +run_name: NAF_deep_wide_2nd +run_path: NAF_deep_wide_2nd model_params: - chans: [32, 64, 128, 256, 256, 256] + chans: [32, 128, 256, 256, 256, 256] enc_blk_nums: [2, 2, 2, 3, 4] middle_blk_num: 12 dec_blk_nums: [2, 2, 2, 2, 2] From b70414b19ad26af27d438b0ef1d60b46f8ed0410 Mon Sep 17 00:00:00 2001 From: Ryan Mueller Date: Thu, 23 Oct 2025 20:09:09 -0400 Subject: [PATCH 20/56] Adding to fuse --- config.yaml | 8 ++-- src/Restorer/Cond_NAF.py | 99 ++++++++++++++++++++++++++++++++-------- 2 files changed, 85 insertions(+), 22 deletions(-) diff --git a/config.yaml b/config.yaml index 61acb29..63d97cd 100644 --- a/config.yaml +++ b/config.yaml @@ -27,10 +27,10 @@ experiment: NAF_test mlflow_experiment: NAFNet_variations # --- Run Configuration ---: -run_name: NAF_deep_wide_2nd -run_path: NAF_deep_wide_2nd +run_name: NAF_deep_CondFuserAdd +run_path: NAF_deep_CondFuserAdd model_params: - chans: [32, 128, 256, 256, 256, 256] + chans: [32, 64, 128, 256, 256, 256] enc_blk_nums: [2, 2, 2, 3, 4] middle_blk_num: 12 dec_blk_nums: [2, 2, 2, 2, 2] @@ -38,6 +38,8 @@ model_params: in_channels: 4 out_channels: 3 rggb: True + use_CondFuserV2: False + use_add: True # --- Loss Configureation ---: alpha: 0.2 diff --git a/src/Restorer/Cond_NAF.py b/src/Restorer/Cond_NAF.py index f503564..3c0074c 100644 --- a/src/Restorer/Cond_NAF.py +++ b/src/Restorer/Cond_NAF.py @@ -31,12 +31,6 @@ class ConditionedChannelAttention(nn.Module): def __init__(self, dims, cat_dims): super().__init__() in_dim = dims + cat_dims - # self.mlp = nn.Sequential( - # nn.Linear(in_dim, int(in_dim*1.5)), - # nn.GELU(), - # nn.Dropout(0.2), - # nn.Linear(int(in_dim*1.5), dims) - # ) self.mlp = nn.Sequential(nn.Linear(in_dim, dims)) self.pool = nn.AdaptiveAvgPool2d(1) @@ -49,6 +43,76 @@ def forward(self, x, conditioning): return ca +class CondFuser(nn.Module): + def 
__init__(self, chan, cond_chan=1): + super().__init__() + self.cca = ConditionedChannelAttention(chan * 2, cond_chan) + # self.spa = nn.Conv2d( + # in_channels=chan * 2, + # out_channels=1, + # kernel_size=3, + # padding=1, + # stride=1, + # groups=1, + # bias=True, + # ) + + def forward(self, x1, x2, cond): + x = torch.cat([x1, x2], dim=1) + x = self.cca(x, cond) * x + # spa = torch.sigmoid(self.spa(x)) + + x1, x2 = x.chunk(2, dim=1) + # return x1 * spa + x2 * (1 - spa) + return x1 + x2 + + +class NKA(nn.Module): + def __init__(self, dim, channel_reduction = 8): + super().__init__() + + reduced_channels = dim // channel_reduction + self.proj_1 = nn.Conv2d(dim, reduced_channels, 1, 1, 0) + self.dwconv = nn.Conv2d(reduced_channels, reduced_channels, 3, 1, 1, groups=reduced_channels) + self.proj_2 = nn.Conv2d(reduced_channels, reduced_channels * 2, 1, 1, 0) + self.sg = SimpleGate() + self.attention = nn.Conv2d(reduced_channels, dim, 1, 1, 0) + + def forward(self, x): + B, C, H, W = x.shape + # First projection to a smaller dimension + y = self.proj_1(x) + # DW conv + attn = self.dwconv(y) + # PW to increase channel count for SG + attn = self.proj_2(attn) + # Non-linearity + attn = self.sg(attn) + # Back to original dimensions + out = x * self.attention(attn) + return out + +class CondFuserAdd(nn.Module): + def __init__(self, chan, cond_chan=1): + super().__init__() + + def forward(self, x1, x2, cond): + return x1 + x2 + +class CondFuserV2(nn.Module): + def __init__(self, chan, cond_chan=1): + super().__init__() + self.cca = ConditionedChannelAttention(chan * 2, cond_chan) + self.spa = NKA(chan * 2) + + def forward(self, x1, x2, cond): + x = torch.cat([x1, x2], dim=1) + x = self.cca(x, cond) * x + spa = torch.sigmoid(self.spa(x)) * x + x1, x2 = x.chunk(2, dim=1) + return x1 * x2 + + class NAFBlock0(nn.Module): def __init__(self, c, DW_Expand=2, FFN_Expand=2, drop_out_rate=0.0, cond_chans=0): @@ -169,18 +233,7 @@ def forward(self, x, conditioning): y = 
self.fc(y).view(b, c, 1, 1) return x * y.expand_as(x) -class CondFuser(nn.Module): - def __init__(self, chan, cond_chan=1): - super().__init__() - self.cca = ConditionedChannelAttention(chan * 2, cond_chan) - def forward(self, x1, x2, cond): - x = torch.cat([x1, x2], dim=1) - x = self.cca(x, cond) * x - x1, x2 = x.chunk(2, dim=1) - return x1 + x2 - - class Restorer(nn.Module): def __init__( self, @@ -195,7 +248,9 @@ def __init__( expand_dims=2, drop_out_rate=0.0, drop_out_rate_increment=0.0, - rggb = False + rggb = False, + use_CondFuserV2 = False, + use_add = False ): super().__init__() width = chans[0] @@ -290,7 +345,13 @@ def __init__( ) ) drop_out_rate -= drop_out_rate_increment - self.merges.append(CondFuser(next_chan, cond_chan=cond_output)) + if use_CondFuserV2: + self.merges.append(CondFuserV2(next_chan, cond_chan=cond_output)) + elif use_add: + self.merges.append(CondFuserAdd(next_chan, cond_chan=cond_output)) + else: + self.merges.append(CondFuser(next_chan, cond_chan=cond_output)) + self.decoders.append( nn.Sequential( *[ From bed32bb79c84aa6fefef2a7f0ab40fa333eccf63 Mon Sep 17 00:00:00 2001 From: Ryan Mueller Date: Thu, 23 Oct 2025 20:49:31 -0400 Subject: [PATCH 21/56] Testing additive fusion --- 1_pretrain_model.ipynb | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/1_pretrain_model.ipynb b/1_pretrain_model.ipynb index e092d15..1ccfd02 100644 --- a/1_pretrain_model.ipynb +++ b/1_pretrain_model.ipynb @@ -43,7 +43,7 @@ "source": [ "run_config = load_config()\n", "dataset_path = Path(run_config['jpeg_output_subdir'])\n", - "align_csv = dataset_path / run_config['align_csv']" + "align_csv = dataset_path / run_config['secondary_align_csv']" ] }, { @@ -63,6 +63,8 @@ "val_split = run_config['val_split']\n", "crop_size = run_config['crop_size']\n", "experiment = run_config['mlflow_experiment']\n", + "mlflow_path = run_config['mlflow_path']\n", + "mlflow.set_tracking_uri(f\"file://{mlflow_path}\")\n", 
"mlflow.set_experiment(experiment)" ] }, @@ -160,6 +162,14 @@ " )" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "2f95b42f", + "metadata": {}, + "outputs": [], + "source": [] + }, { "cell_type": "code", "execution_count": null, From 5e2d3cc561135e65e340c44133882f62d7e3e6ad Mon Sep 17 00:00:00 2001 From: Ryan Mueller Date: Thu, 23 Oct 2025 20:50:03 -0400 Subject: [PATCH 22/56] Pretrain on raw swatches --- 1_pretrain_model_raw.ipynb | 197 +++++++++++++++++++++++++++++++++++++ 1 file changed, 197 insertions(+) create mode 100644 1_pretrain_model_raw.ipynb diff --git a/1_pretrain_model_raw.ipynb b/1_pretrain_model_raw.ipynb new file mode 100644 index 0000000..ec5273e --- /dev/null +++ b/1_pretrain_model_raw.ipynb @@ -0,0 +1,197 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "f6351e77", + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "import matplotlib.pyplot as plt\n", + "from torch.utils.data import DataLoader, random_split\n", + "import torch.nn as nn\n", + "import torch\n", + "import copy\n", + "import mlflow\n", + "import mlflow.pytorch\n", + "from pathlib import Path" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2043dc8e", + "metadata": {}, + "outputs": [], + "source": [ + "from src.training.SmallRawDatasetNumpy import SmallRawDatasetNumpy\n", + "from src.training.losses.ShadowAwareLoss import ShadowAwareLoss\n", + "from src.training.VGGFeatureExtractor import VGGFeatureExtractor\n", + "from src.training.train_loop import train_one_epoch, visualize\n", + "from src.training.utils import apply_gamma_torch\n", + "from src.training.load_config import load_config\n", + "from src.Restorer.Cond_NAF import make_full_model_RGGB\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a0464c98", + "metadata": {}, + "outputs": [], + "source": [ + "run_config = load_config()\n", + "dataset_path = Path(run_config['cropped_raw_subdir'])\n", + 
"align_csv = dataset_path / run_config['secondary_align_csv']" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ba20b866", + "metadata": {}, + "outputs": [], + "source": [ + "device=run_config['device']\n", + "\n", + "batch_size = run_config['batch_size']\n", + "lr = run_config['lr_base'] * batch_size\n", + "clipping = run_config['clipping']\n", + "\n", + "num_epochs = run_config['num_epochs_pretraining']\n", + "val_split = run_config['val_split']\n", + "crop_size = run_config['crop_size']\n", + "experiment = run_config['mlflow_experiment']\n", + "mlflow_path = run_config['mlflow_path']\n", + "mlflow.set_tracking_uri(f\"file://{mlflow_path}\")\n", + "mlflow.set_experiment(experiment)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a36eaef2", + "metadata": {}, + "outputs": [], + "source": [ + "model_params = run_config['model_params']\n", + "rggb = model_params['rggb']" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "15f16fa7", + "metadata": {}, + "outputs": [], + "source": [ + "dataset = SmallRawDatasetNumpy(dataset_path, align_csv, crop_size=crop_size)\n", + "\n", + "# Split dataset into train and val\n", + "val_size = int(len(dataset) * val_split)\n", + "train_size = len(dataset) - val_size\n", + "torch.manual_seed(42) # For reproducibility\n", + "train_dataset, val_dataset = random_split(dataset, [train_size, val_size])\n", + "# Set the validation dataset to use the same crops\n", + "val_dataset = copy.deepcopy(val_dataset)\n", + "val_dataset.dataset.validation = True\n", + "\n", + "train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False, num_workers=0)\n", + "val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "086fbb3a", + "metadata": {}, + "outputs": [], + "source": [ + "model = make_full_model_RGGB(model_params, model_name=None)\n", + "model = 
model.to(device)\n", + "\n", + "params = {**run_config, **model_params}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6af0f3a2", + "metadata": {}, + "outputs": [], + "source": [ + "optimizer = torch.optim.Adam(model.parameters(), lr=lr)\n", + "\n", + "vfe = VGGFeatureExtractor(config=((1, 64), (1, 128), (1, 256), (1, 512), (1, 512),), \n", + " feature_layers=[14], \n", + " activation=nn.ReLU\n", + " )\n", + "vfe = vfe.to(device)\n", + "\n", + "loss_fn = ShadowAwareLoss(\n", + " alpha=run_config['alpha'],\n", + " beta=run_config['beta'],\n", + " l1_weight=run_config['l1_weight'],\n", + " ssim_weight=run_config['ssim_weight'],\n", + " tv_weight=run_config['tv_weight'],\n", + " vgg_loss_weight=run_config['vgg_loss_weight'],\n", + " apply_gamma_fn=apply_gamma_torch,\n", + " vgg_feature_extractor=vfe,\n", + " device=device,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8bc86a06", + "metadata": {}, + "outputs": [], + "source": [ + "with mlflow.start_run(run_name=run_config['run_name']) as run:\n", + "\n", + " mlflow.log_params(params)\n", + " for epoch in range(num_epochs):\n", + " train_one_epoch(epoch, model, optimizer, train_loader, device, loss_fn, clipping, \n", + " log_interval = 10, sleep=0.0, rggb=rggb, max_batches=0)\n", + " \n", + " mlflow.pytorch.log_model(\n", + " pytorch_model=model,\n", + " name=run_config['run_path'],\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5a775d31", + "metadata": {}, + "outputs": [], + "source": [ + "run.info.run_id" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "OnSight", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.11" + } + }, + "nbformat": 4, + "nbformat_minor": 5 
+} From 95149ced4a93fac4e0b13e6d259289b8e16017a4 Mon Sep 17 00:00:00 2001 From: Ryan Mueller Date: Sun, 26 Oct 2025 18:43:01 -0400 Subject: [PATCH 23/56] Code to train on cropped raw --- ...el_raw.ipynb => 2_pretrain_model_raw.ipynb | 43 +++++++++++-------- 1 file changed, 26 insertions(+), 17 deletions(-) rename 1_pretrain_model_raw.ipynb => 2_pretrain_model_raw.ipynb (90%) diff --git a/1_pretrain_model_raw.ipynb b/2_pretrain_model_raw.ipynb similarity index 90% rename from 1_pretrain_model_raw.ipynb rename to 2_pretrain_model_raw.ipynb index ec5273e..44de687 100644 --- a/1_pretrain_model_raw.ipynb +++ b/2_pretrain_model_raw.ipynb @@ -46,6 +46,16 @@ "align_csv = dataset_path / run_config['secondary_align_csv']" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "7322456f", + "metadata": {}, + "outputs": [], + "source": [ + "dataset_path" + ] + }, { "cell_type": "code", "execution_count": null, @@ -64,6 +74,7 @@ "crop_size = run_config['crop_size']\n", "experiment = run_config['mlflow_experiment']\n", "mlflow_path = run_config['mlflow_path']\n", + "rggb = True\n", "mlflow.set_tracking_uri(f\"file://{mlflow_path}\")\n", "mlflow.set_experiment(experiment)" ] @@ -71,12 +82,24 @@ { "cell_type": "code", "execution_count": null, - "id": "a36eaef2", + "id": "e9a26124", "metadata": {}, "outputs": [], "source": [ - "model_params = run_config['model_params']\n", - "rggb = model_params['rggb']" + "\n", + "RUN_ID = \"425568ac95d340d7a59c624233269207\" \n", + "ARTIFACT_PATH = run_config['run_path']\n", + "\n", + "model_uri = f\"runs:/{RUN_ID}/{ARTIFACT_PATH}\"\n", + "\n", + "try:\n", + " model = mlflow.pytorch.load_model(model_uri)\n", + " model.eval()\n", + " print(f\"Model successfully loaded from MLflow URI: {model_uri}\")\n", + " \n", + "\n", + "except Exception as e:\n", + " print(f\"Error loading model from MLflow: {e}\")" ] }, { @@ -101,19 +124,6 @@ "val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0)" ] }, - { - 
"cell_type": "code", - "execution_count": null, - "id": "086fbb3a", - "metadata": {}, - "outputs": [], - "source": [ - "model = make_full_model_RGGB(model_params, model_name=None)\n", - "model = model.to(device)\n", - "\n", - "params = {**run_config, **model_params}" - ] - }, { "cell_type": "code", "execution_count": null, @@ -151,7 +161,6 @@ "source": [ "with mlflow.start_run(run_name=run_config['run_name']) as run:\n", "\n", - " mlflow.log_params(params)\n", " for epoch in range(num_epochs):\n", " train_one_epoch(epoch, model, optimizer, train_loader, device, loss_fn, clipping, \n", " log_interval = 10, sleep=0.0, rggb=rggb, max_batches=0)\n", From 7429e26826248c4f341b469ac2e1fae42e1c7906 Mon Sep 17 00:00:00 2001 From: Ryan Mueller Date: Sun, 26 Oct 2025 18:51:44 -0400 Subject: [PATCH 24/56] Update to trace making script --- 3_make_script.ipynb | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/3_make_script.ipynb b/3_make_script.ipynb index 624036e..171ad30 100644 --- a/3_make_script.ipynb +++ b/3_make_script.ipynb @@ -25,22 +25,27 @@ { "cell_type": "code", "execution_count": null, - "id": "72145448", + "id": "5f499abe", "metadata": {}, "outputs": [], "source": [ - "run_config" + "run_config = load_config()\n", + "output = Path(run_config['script_path']) / f\"{run_config['run_path']}.pt\"\n", + "experiment = run_config['mlflow_experiment']\n", + "mlflow_path = run_config['mlflow_path']\n", + "rggb = True\n", + "mlflow.set_tracking_uri(f\"file://{mlflow_path}\")\n", + "mlflow.set_experiment(experiment)" ] }, { "cell_type": "code", "execution_count": null, - "id": "5f499abe", + "id": "0655be14", "metadata": {}, "outputs": [], "source": [ - "run_config = load_config()\n", - "output = Path(run_config['script_path']) / f\"{run_config['run_path']}.pt\"" + "output" ] }, { @@ -50,7 +55,8 @@ "metadata": {}, "outputs": [], "source": [ - "RUN_ID = \"66f11f4639f24e9ea75b4c953147be15\" \n", + "\n", + "RUN_ID = 
\"10df0d1d6eba4c6b887086d69ab390a7\" \n", "ARTIFACT_PATH = run_config['run_path']\n", "\n", "model_uri = f\"runs:/{RUN_ID}/{ARTIFACT_PATH}\"\n", @@ -58,7 +64,7 @@ "try:\n", " model = mlflow.pytorch.load_model(model_uri)\n", " model.eval()\n", - " model.to('cpu')\n", + " model = model.to('cpu')\n", " print(f\"Model successfully loaded from MLflow URI: {model_uri}\")\n", " \n", "\n", From c2e5f968b2d709e9461377a9e6a56f565ab85d9a Mon Sep 17 00:00:00 2001 From: Ryan Mueller Date: Sun, 26 Oct 2025 18:52:04 -0400 Subject: [PATCH 25/56] Added additional losses to test --- src/training/losses/ShadowAwareLoss.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/src/training/losses/ShadowAwareLoss.py b/src/training/losses/ShadowAwareLoss.py index 447fa0a..3714932 100644 --- a/src/training/losses/ShadowAwareLoss.py +++ b/src/training/losses/ShadowAwareLoss.py @@ -1,7 +1,7 @@ import torch import torch.nn as nn from pytorch_msssim import ms_ssim - +from src.training.losses.CombinedPerceptualLoss import VGGPerceptualLoss class ShadowAwareLoss(nn.Module): def __init__(self, @@ -13,6 +13,7 @@ def __init__(self, vgg_loss_weight=0.0, apply_gamma_fn=None, vgg_feature_extractor=None, + percept_loss_weight = 0, device=None): """ Shadow-aware image restoration loss. 
@@ -35,6 +36,8 @@ def __init__(self, self.apply_gamma_fn = apply_gamma_fn self.vfe = vgg_feature_extractor self.device = device + self.percept_loss_weight = percept_loss_weight + self.VGGPerceptualLoss = VGGPerceptualLoss() if device is not None: self.to(device) @@ -77,12 +80,16 @@ def forward(self, pred, target): target_features = self.vfe(target) vgg_loss_val = nn.functional.mse_loss(pred_features[0], target_features[0]) + percept_loss = 0 + if self.percept_loss_weight: + percept_loss = self.VGGPerceptualLoss(pred, target) # Combine weighted terms total_loss = ( self.l1_weight * l1 + self.ssim_weight * ssim + self.tv_weight * tv + - self.vgg_loss_weight * vgg_loss_val + self.vgg_loss_weight * vgg_loss_val + + self.percept_loss_weight * percept_loss ) return total_loss From 0a8a86f24bda42685a01ae51faa506f0caf95bf3 Mon Sep 17 00:00:00 2001 From: Ryan Mueller Date: Sun, 26 Oct 2025 18:52:22 -0400 Subject: [PATCH 26/56] Updating align script to preserve sharpness --- src/training/align_images.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/training/align_images.py b/src/training/align_images.py index f24ad0e..ab0a56a 100644 --- a/src/training/align_images.py +++ b/src/training/align_images.py @@ -114,7 +114,7 @@ def safe_gray(img): -def apply_alignment(img, warp_params, interpolation=cv2.INTER_LINEAR): +def apply_alignment(img, warp_params, interpolation=cv2.INTER_LANCZOS4): """ Applies a previously estimated affine warp to an image. warp_params: dict with keys m00..m12 or a 2x3 numpy array. 
From b079f703a58582add5498f0d4053b18069d0267c Mon Sep 17 00:00:00 2001 From: Ryan Mueller Date: Sun, 26 Oct 2025 18:52:45 -0400 Subject: [PATCH 27/56] Added modules being tested --- src/Restorer/Cond_NAF.py | 289 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 285 insertions(+), 4 deletions(-) diff --git a/src/Restorer/Cond_NAF.py b/src/Restorer/Cond_NAF.py index 3c0074c..efebe4a 100644 --- a/src/Restorer/Cond_NAF.py +++ b/src/Restorer/Cond_NAF.py @@ -92,6 +92,18 @@ def forward(self, x): out = x * self.attention(attn) return out + +class CondFuser(nn.Module): + def __init__(self, chan, cond_chan=1): + super().__init__() + self.cca = ConditionedChannelAttention(chan * 2, cond_chan) + + def forward(self, x1, x2, cond): + x = torch.cat([x1, x2], dim=1) + x = self.cca(x, cond) * x + x1, x2 = x.chunk(2, dim=1) + return x1 + x2 + class CondFuserAdd(nn.Module): def __init__(self, chan, cond_chan=1): super().__init__() @@ -113,6 +125,39 @@ def forward(self, x1, x2, cond): return x1 * x2 +class CondFuserV3(nn.Module): + def __init__(self, chan, cond_chan=1): + super().__init__() + self.cca = ConditionedChannelAttention(chan * 2, cond_chan) + self.spa = nn.Conv2d( + in_channels=chan * 2, + out_channels=1, + kernel_size=3, + padding=1, + stride=1, + groups=1, + bias=True, + ) + + def forward(self, x1, x2, cond): + x = torch.cat([x1, x2], dim=1) + x = self.cca(x, cond) * x + spa = torch.sigmoid(self.spa(x)) + + x1, x2 = x.chunk(2, dim=1) + return x1 * spa + x2 * (1 - spa) + +class CondFuserV4(nn.Module): + def __init__(self, chan, cond_chan=1): + super().__init__() + self.cca = ConditionedChannelAttention(chan * 2, cond_chan) + self.pw = nn.Conv2d(chan * 2, chan, 1, stride=1, padding=0, groups=1) + def forward(self, x1, x2, cond): + x = torch.cat([x1, x2], dim=1) + x = self.cca(x, cond) * x + x = self.pw(x) + return x + class NAFBlock0(nn.Module): def __init__(self, c, DW_Expand=2, FFN_Expand=2, drop_out_rate=0.0, cond_chans=0): @@ -213,6 +258,173 @@ def forward(self, 
input): x = self.dropout2(x) return (y + x * self.gamma, cond) + + +class NAFBlock0_learned_norm(nn.Module): + def __init__(self, c, DW_Expand=2, FFN_Expand=2, drop_out_rate=0.0, cond_chans=0): + super().__init__() + dw_channel = c * DW_Expand + self.conv1 = nn.Conv2d( + in_channels=c, + out_channels=dw_channel, + kernel_size=1, + padding=0, + stride=1, + groups=1, + bias=True, + ) + self.conv2 = nn.Conv2d( + in_channels=dw_channel, + out_channels=dw_channel, + kernel_size=3, + padding=1, + stride=1, + groups=dw_channel, + bias=True, + ) + self.conv3 = nn.Conv2d( + in_channels=dw_channel // 2, + out_channels=c, + kernel_size=1, + padding=0, + stride=1, + groups=1, + bias=True, + ) + + # Simplified Channel Attention + self.sca = ConditionedChannelAttention(dw_channel // 2, cond_chans) + + # SimpleGate + self.sg = SimpleGate() + + ffn_channel = FFN_Expand * c + self.conv4 = nn.Conv2d( + in_channels=c, + out_channels=ffn_channel, + kernel_size=1, + padding=0, + stride=1, + groups=1, + bias=True, + ) + self.conv5 = nn.Conv2d( + in_channels=ffn_channel // 2, + out_channels=c, + kernel_size=1, + padding=0, + stride=1, + groups=1, + bias=True, + ) + + # self.grn = GRN(ffn_channel // 2) + + self.norm1 = LayerNorm2d(c) + self.norm2 = LayerNorm2d(c) + + self.dropout1 = ( + nn.Dropout(drop_out_rate) if drop_out_rate > 0.0 else nn.Identity() + ) + self.dropout2 = ( + nn.Dropout(drop_out_rate) if drop_out_rate > 0.0 else nn.Identity() + ) + + self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True) + self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True) + self.sca_mul = ConditionedChannelAttention(c, cond_chans) + self.sca_add = ConditionedChannelAttention(c, cond_chans) + + def forward(self, input): + inp = input[0] + cond = input[1] + + x = inp + + x = self.norm1(x) + x = self.conv1(x) + x = self.conv2(x) + x = self.sg(x) + x = x * self.sca(x, cond) + x = self.conv3(x) + + x = self.dropout1(x) + + y = inp + x * self.beta + + # Channel Mixing 
+ normed = self.norm2(y) + + # Input mediated channel attention, obstensibly to mitigate the effects of group norm on flat scenes + x = (1 + self.sca_mul(inp, cond)) * normed + self.sca_add(inp, cond) + + x = self.conv4(x) + x = self.sg(x) + # x = self.grn(x) + x = self.conv5(x) + + x = self.dropout2(x) + + + + return (y + x * self.gamma, cond) + + +import torch.nn.functional as F + +class SwiGLU(nn.Module): + + def __init__(self, input_dim, hidden_dim, dropout=0.1): + super().__init__() + self.w1 = nn.Conv2d(input_dim, hidden_dim, 1, 1, 0, 1) + self.w2 = nn.Conv2d(input_dim, hidden_dim, 1, 1, 0, 1) + self.w3 = nn.Conv2d(hidden_dim, input_dim, 1, 1, 0, 1) + + def forward(self, x): + gate = F.silu(self.w1(x)) + value = self.w2(x) + x = gate * value + + x = self.w3(x) + return x + +class AttnBlock(nn.Module): + def __init__(self, c, FFN_Expand=2, drop_out_rate=0.0, cond_chans=0): + super().__init__() + + self.dw = nn.Conv2d( + in_channels=c, + out_channels=c, + kernel_size=3, + padding=1, + stride=1, + groups=c, + bias=True, + ) + self.nka = NKA(c) + + self.sca = ConditionedChannelAttention(c, cond_chans) + + self.norm = nn.GroupNorm(1, c) + + self.swiglu = SwiGLU(c, int(c * FFN_Expand)) + self.alpha = nn.Parameter(torch.zeros(1, c, 1, 1)) + self.beta = nn.Parameter(torch.zeros(1, c, 1, 1)) + + + def forward(self, input): + inp = input[0] + cond = input[1] + + x = self.dw(inp) + x = self.nka(x) + x = self.sca(x, cond) * x + y = self.norm(inp + self.alpha * x ) + + + x = self.swiglu(y) + x = y + self.beta * x + return (x, cond) class CondSEBlock(nn.Module): @@ -234,6 +446,37 @@ def forward(self, x, conditioning): return x * y.expand_as(x) + +class ConditioningCNN(nn.Module): + def __init__(self, in_channels=4, num_logits=128): + """ + Args: + in_channels (int): Number of input channels (e.g., 3 for RGB). + num_logits (int): The desired size of the output 1D logit vector. 
+ """ + super().__init__() + + self.encoder = nn.Sequential( + nn.Conv2d(in_channels, 32, kernel_size=3, stride=2, padding='same'), + nn.ReLU(inplace=True), + nn.Conv2d(32, 64, kernel_size=3, stride=2, padding='same'), + nn.ReLU(inplace=True), + nn.Conv2d(64, 128, kernel_size=3, stride=2, padding='same'), + nn.ReLU(inplace=True), + nn.Conv2d(128, 256, kernel_size=3, stride=2, padding='same'), + nn.ReLU(inplace=True) + ) + + self.logit_head = nn.Sequential( + nn.AdaptiveAvgPool2d((1, 1)), + nn.Flatten(), + nn.Linear(256, num_logits) + ) + def forward(self, x): + x = self.encoder(x) + x = self.logit_head(x) + return x + class Restorer(nn.Module): def __init__( self, @@ -250,15 +493,40 @@ def __init__( drop_out_rate_increment=0.0, rggb = False, use_CondFuserV2 = False, - use_add = False + use_add = False, + use_CondFuserV3 = False, + use_CondFuserV4 = False, + use_attnblock = False, + use_NAFBlock0_learned_norm=False, + use_cond_net = False, + cond_net_num = 32, + use_input_stats=False, ): super().__init__() + if use_attnblock: + block = AttnBlock + elif use_NAFBlock0_learned_norm: + block = NAFBlock0_learned_norm + else: + block = NAFBlock0 + width = chans[0] self.expand_dims = expand_dims self.conditioning_gen = nn.Sequential( nn.Linear(cond_input, 64), nn.ReLU(), nn.Dropout(drop_out_rate), nn.Linear(64, cond_output), ) + + + self.use_cond_net = use_cond_net + if use_cond_net: + self.cond_net = ConditioningCNN(in_channels=in_channels, num_logits=cond_net_num) + cond_output = cond_output + cond_net_num + + self.use_input_stats = use_input_stats + if use_input_stats: + cond_output = cond_output + in_channels * 2 + self.rggb = rggb if not rggb: self.intro = nn.Conv2d( @@ -320,7 +588,7 @@ def __init__( self.encoders.append( nn.Sequential( *[ - NAFBlock0(current_chan, cond_chans=cond_output, drop_out_rate=drop_out_rate) + block(current_chan, cond_chans=cond_output, drop_out_rate=drop_out_rate) for _ in range(num) ] ) @@ -330,7 +598,7 @@ def __init__( self.middle_blks = 
nn.Sequential( *[ - NAFBlock0(next_chan, cond_chans=cond_output, drop_out_rate=drop_out_rate) + block(next_chan, cond_chans=cond_output, drop_out_rate=drop_out_rate) for _ in range(middle_blk_num) ] ) @@ -349,13 +617,17 @@ def __init__( self.merges.append(CondFuserV2(next_chan, cond_chan=cond_output)) elif use_add: self.merges.append(CondFuserAdd(next_chan, cond_chan=cond_output)) + elif use_CondFuserV3: + self.merges.append(CondFuserV3(next_chan, cond_chan=cond_output)) + elif use_CondFuserV4: + self.merges.append(CondFuserV4(next_chan, cond_chan=cond_output)) else: self.merges.append(CondFuser(next_chan, cond_chan=cond_output)) self.decoders.append( nn.Sequential( *[ - NAFBlock0(next_chan, cond_chans=cond_output, drop_out_rate=drop_out_rate) + block(next_chan, cond_chans=cond_output, drop_out_rate=drop_out_rate) for _ in range(num) ] ) @@ -368,6 +640,15 @@ def forward(self, inp, cond_in): # Conditioning: cond = self.conditioning_gen(cond_in) + if self.use_cond_net: + extra_cond = self.cond_net(inp) + cond = torch.cat([cond, extra_cond], dim=1) + if self.use_input_stats: + mu = inp.mean((2,3), keepdim=True) + var = (inp - mu).pow(2).mean((2,3), keepdim=False) + mu = mu.squeeze(-1).squeeze(-1) + cond = torch.cat([cond, mu, var], dim=1) + B, C, H, W = inp.shape if self.rggb: H = 2 * H From 80ce72efecb2aa17870120e01e875fd549f4875e Mon Sep 17 00:00:00 2001 From: Ryan Mueller Date: Sun, 26 Oct 2025 18:52:56 -0400 Subject: [PATCH 28/56] Up to date config --- config.yaml | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/config.yaml b/config.yaml index 63d97cd..4dfd53a 100644 --- a/config.yaml +++ b/config.yaml @@ -7,7 +7,7 @@ cropped_raw_subdir: /Volumes/EasyStore/RAWNIND/Cropped_Raw cropped_raw_size: 2000 align_csv: align_data.csv secondary_align_csv: align_phase_v2.csv -script_path: traces_and_weights +script_path: /Volumes/EasyStore/models/traces mlflow_path: /Volumes/EasyStore/models/mlfow # --- Training Params --- @@ -22,13 +22,14 
@@ num_epochs_finetuning: 20 val_split: 0.2 random_seed: 42 + # --- Experiment Settings --- experiment: NAF_test mlflow_experiment: NAFNet_variations # --- Run Configuration ---: -run_name: NAF_deep_CondFuserAdd -run_path: NAF_deep_CondFuserAdd +run_name: NAF_use_input_stats +run_path: NAF_deep_test_align model_params: chans: [32, 64, 128, 256, 256, 256] enc_blk_nums: [2, 2, 2, 3, 4] @@ -39,7 +40,12 @@ model_params: out_channels: 3 rggb: True use_CondFuserV2: False - use_add: True + use_add: False + use_CondFuserV3: False + use_attnblock: False + use_CondFuserV4: False + use_NAFBlock0_learned_norm: True + use_input_stats: True # --- Loss Configureation ---: alpha: 0.2 @@ -47,4 +53,5 @@ beta: 5.0 l1_weight: 0.16 ssim_weight: 0.84 tv_weight: 0.0 -vgg_loss_weight: 0 \ No newline at end of file +vgg_loss_weight: 0 +percept_loss_weight: 0 \ No newline at end of file From bba2c7041eb4497c8be7dbdaf3dc595b6e817f00 Mon Sep 17 00:00:00 2001 From: Ryan Mueller Date: Sun, 26 Oct 2025 18:53:43 -0400 Subject: [PATCH 29/56] Updated gitignore --- .gitignore | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/.gitignore b/.gitignore index fd9e85c..10f7995 100644 --- a/.gitignore +++ b/.gitignore @@ -194,5 +194,9 @@ cython_debug/ .cursorindexingignore -# MLrn -mlruns/ \ No newline at end of file +# custom ignores +mlruns/ +.DS_store +*.png +*.jpeg +*.csv \ No newline at end of file From 178d089e5dc9542119a6e3e4014e378c41491606 Mon Sep 17 00:00:00 2001 From: Ryan Mueller Date: Tue, 28 Oct 2025 13:31:27 -0400 Subject: [PATCH 30/56] Updating cropped raw to better reflect the real raw data. 
--- 0_produce_small_dataset_raw.ipynb | 13 +++++++------ src/training/SmallRawDatasetNumpy.py | 7 +++---- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/0_produce_small_dataset_raw.ipynb b/0_produce_small_dataset_raw.ipynb index 3ae0532..9d4ca62 100644 --- a/0_produce_small_dataset_raw.ipynb +++ b/0_produce_small_dataset_raw.ipynb @@ -165,11 +165,12 @@ " bottom = top + crop_size\n", "\n", " im = rh.apply_colorspace_transform(dims=(left, right, top, bottom), colorspace=colorspace)\n", - " im_scaled = im * 65535.0\n", - " im_clipped = np.clip(im_scaled, 0.0, 65535.0)\n", + " im = im.astype(np.float16)\n", + " # im_scaled = im * 65535.0\n", + " # im_clipped = np.clip(im_scaled, 0.0, 65535.0)\n", "\n", - " im_uint16 = im_clipped.astype(np.uint16)\n", - " return im_uint16" + " # im_uint16 = im_clipped.astype(np.uint16)\n", + " return im" ] }, { @@ -194,10 +195,10 @@ " for noisy, gt in image_pairs:\n", " try:\n", " noisy_bayer = get_file(f'{raw_path}/{noisy}')\n", - " noisy_path = outpath / (noisy + \".u16.raw\")\n", + " noisy_path = outpath / (noisy + \".f16.raw\")\n", " noisy_bayer.tofile(noisy_path)\n", "\n", - " gt_path = outpath / (gt + \".u16.raw\")\n", + " gt_path = outpath / (gt + \".f16.raw\")\n", " if not os.path.exists(gt_path):\n", " gt_bayer = get_file(f'{raw_path}/{gt}')\n", " gt_bayer.tofile(gt_path)\n", diff --git a/src/training/SmallRawDatasetNumpy.py b/src/training/SmallRawDatasetNumpy.py index 806d67c..f624699 100644 --- a/src/training/SmallRawDatasetNumpy.py +++ b/src/training/SmallRawDatasetNumpy.py @@ -25,7 +25,7 @@ def __init__(self, path, csv, crop_size=180, buffer=10, validation=False, run_al self.coordinate_iso = 6400 self.validation=validation self.run_align = run_align - self.dtype = np.uint16 + self.dtype = np.float16 self.dimensions = dimensions def __len__(self): return len(self.df) @@ -46,8 +46,8 @@ def __getitem__(self, idx): - gt_image = gt_image/65535 - bayer_data = bayer_data/65535 + gt_image = gt_image + bayer_data = 
bayer_data gt_image = demosaicing_CFA_Bayer_Malvar2004(gt_image) demosaiced_noisy = demosaicing_CFA_Bayer_Malvar2004(bayer_data) @@ -58,7 +58,6 @@ def __getitem__(self, idx): aligned = aligned / 255 else: aligned = apply_alignment(gt_image, row.to_dict()) - h, w, _ = gt_image.shape #Crop images From 567545e60f678142c30e3b6b11aacdae950d8ec9 Mon Sep 17 00:00:00 2001 From: Ryan Mueller Date: Tue, 28 Oct 2025 16:42:53 -0400 Subject: [PATCH 31/56] Censored offset func --- src/training/censored_fit.py | 96 ++++++++++++++++++++++++++++++++++++ 1 file changed, 96 insertions(+) create mode 100644 src/training/censored_fit.py diff --git a/src/training/censored_fit.py b/src/training/censored_fit.py new file mode 100644 index 0000000..6b41e4d --- /dev/null +++ b/src/training/censored_fit.py @@ -0,0 +1,96 @@ +import numpy as np +from scipy.stats import norm + +def censored_linear_fit_twosided(x, y, clip_low=None, clip_high=None, max_iter=200, tol=1e-6, include_offset=True): + """ + Fit y ≈ a + b*x + ε, ε ~ N(0, σ²) under two-sided censoring: + clip_low ≤ y_true ≤ clip_high + Observed y are clipped to [clip_low, clip_high]. + Returns (a, b, sigma) estimated via EM. + + Parameters + ---------- + x, y : array_like + Input data. + clip_low, clip_high : float or None + Lower/upper clip levels. Can be None for one-sided clipping. + max_iter : int + Maximum EM iterations. + tol : float + Relative tolerance for convergence. + include_offset: bool + Compute linear fit with offset (y = b * x + a) + Returns + ------- + a, b, sigma : floats + Estimated regression parameters. 
+ """ + + x = np.asarray(x).ravel() + y = np.asarray(y).ravel() + mask = np.isfinite(x) & np.isfinite(y) + x, y = x[mask], y[mask] + + n = len(x) + if n < 3: + raise ValueError("Not enough data points.") + + # --- initial guess (ordinary least squares) --- + if include_offset: + A = np.vstack([np.ones_like(x), x]).T + else: + A = np.vstack([x]).T + coef, *_ = np.linalg.lstsq(A, y, rcond=None) + if include_offset: + a, b = coef + else: + a, b = 0, coef[0] + sigma = np.std(y - (a + b*x)) + + for _ in range(max_iter): + mu = a + b*x + y_exp = y.copy() + + # Handle right-censoring (high clip) + if clip_high is not None: + high_mask = y >= clip_high - 1e-12 + if np.any(high_mask): + z = (clip_high - mu[high_mask]) / sigma + Phi = norm.cdf(z) + phi = norm.pdf(z) + one_minus_Phi = 1.0 - Phi + lambda_ = np.zeros_like(z) + valid = one_minus_Phi > 1e-15 + lambda_[valid] = phi[valid] / one_minus_Phi[valid] + y_exp[high_mask] = mu[high_mask] + sigma * lambda_ + + # Handle left-censoring (low clip) + if clip_low is not None: + low_mask = y <= clip_low + 1e-12 + if np.any(low_mask): + z = (clip_low - mu[low_mask]) / sigma + Phi = norm.cdf(z) + phi = norm.pdf(z) + lambda_ = np.zeros_like(z) + valid = Phi > 1e-15 + lambda_[valid] = -phi[valid] / Phi[valid] + y_exp[low_mask] = mu[low_mask] + sigma * lambda_ + + # M-step: re-fit with imputed expectations + if include_offset: + A = np.vstack([np.ones_like(x), x]).T + else: + A = np.vstack([x]).T + coef, *_ = np.linalg.lstsq(A, y, rcond=None) + if include_offset: + a_new, b_new = coef + else: + a_new, b_new = 0, coef[0] + sigma_new = np.std(y_exp - (a_new + b_new*x)) + + if np.allclose([a, b, sigma], [a_new, b_new, sigma_new], rtol=tol, atol=tol): + break + a, b, sigma = a_new, b_new, sigma_new + + return a, b, sigma + From c2bad0b7da33d61b135ee5ff3392de51d238807b Mon Sep 17 00:00:00 2001 From: Ryan Mueller Date: Wed, 29 Oct 2025 18:53:59 -0400 Subject: [PATCH 32/56] Update to save dng --- 0_produce_small_dataset_raw.ipynb | 659 
+++++++++++++++++++++++++++++- 1 file changed, 656 insertions(+), 3 deletions(-) diff --git a/0_produce_small_dataset_raw.ipynb b/0_produce_small_dataset_raw.ipynb index 9d4ca62..0637156 100644 --- a/0_produce_small_dataset_raw.ipynb +++ b/0_produce_small_dataset_raw.ipynb @@ -133,6 +133,193 @@ "pair_file_list = pair_images_by_scene(file_list)" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "206ab100", + "metadata": {}, + "outputs": [], + "source": [ + "def get_file(impath, crop_size=crop_size):\n", + " rh = RawHandler(impath)\n", + " \n", + " width, height = rh.raw.shape\n", + "\n", + " # Check if the image is large enough to be cropped.\n", + " if width < crop_size or height < crop_size:\n", + " im = rh.apply_colorspace_transform(colorspace=colorspace)\n", + " else:\n", + " # Calculate the coordinates for the center crop.\n", + " left = (width - crop_size) // 2\n", + " top = (height - crop_size) // 2\n", + "\n", + " # Ensure the top-left corner is on an even pixel coordinate for Bayer alignment.\n", + " if left % 2 != 0:\n", + " left -= 1\n", + " if top % 2 != 0:\n", + " top -= 1\n", + " \n", + " # Calculate the bottom-right corner based on the adjusted top-left corner.\n", + " # Since crop_size is even, right and bottom will also be even.\n", + " right = left + crop_size\n", + " bottom = top + crop_size\n", + " return rh.raw[left:right, top:bottom], rh" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c29f0629", + "metadata": {}, + "outputs": [], + "source": [ + "from tqdm import tqdm" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a5b63be2", + "metadata": {}, + "outputs": [], + "source": [ + "from pidng.core import RAW2DNG, DNGTags, Tag\n", + "from pidng.defs import *\n", + "\n", + "def get_ratios(string, rh):\n", + " return [x.as_integer_ratio() for x in rh.full_metadata[string].values]\n", + "\n", + "\n", + "def rational_wb(rh, denominator=1000):\n", + " wb = 
np.array(rh.core_metadata.camera_white_balance)\n", + " numerator_matrix = np.round(wb * denominator).astype(int)\n", + " return [[num, denominator] for num in numerator_matrix]\n", + "def convert_ccm_to_rational(matrix_3x3, denominator=10000):\n", + "\n", + " numerator_matrix = np.round(matrix_3x3 * denominator).astype(int)\n", + " numerators_flat = numerator_matrix.flatten()\n", + " ccm_rational = [[num, denominator] for num in numerators_flat]\n", + " \n", + " return ccm_rational\n", + "\n", + "\n", + "def get_as_shot_neutral(rh, denominator=10000):\n", + "\n", + " cam_mul = rh.core_metadata.camera_white_balance\n", + " \n", + " if cam_mul[0] == 0 or cam_mul[2] == 0:\n", + " return [[denominator, denominator], [denominator, denominator], [denominator, denominator]]\n", + "\n", + " r_neutral = cam_mul[1] / cam_mul[0]\n", + " g_neutral = 1.0 \n", + " b_neutral = cam_mul[1] / cam_mul[2]\n", + "\n", + " return [\n", + " [int(r_neutral * denominator), denominator],\n", + " [int(g_neutral * denominator), denominator],\n", + " [int(b_neutral * denominator), denominator],\n", + " ]\n", + "\n", + "\n", + "def to_dng(uint_img, rh, filepath):\n", + " width = uint_img.shape[1]\n", + " height = uint_img.shape[0]\n", + " bpp = 16\n", + "\n", + " exposures = get_ratios('EXIF ExposureTime', rh)\n", + " fnumber = get_ratios('EXIF FNumber', rh)\n", + " ExposureBiasValue = get_ratios('EXIF ExposureBiasValue', rh) \n", + " FocalLength = get_ratios('EXIF FocalLength', rh) \n", + " ccm1 = convert_ccm_to_rational(rh.core_metadata.rgb_xyz_matrix[:3, :])\n", + " t = DNGTags()\n", + " t.set(Tag.ImageWidth, width)\n", + " t.set(Tag.ImageLength, height)\n", + " t.set(Tag.TileWidth, width)\n", + " t.set(Tag.TileLength, height)\n", + " t.set(Tag.BitsPerSample, bpp)\n", + "\n", + " t.set(Tag.SamplesPerPixel, 1) \n", + " t.set(Tag.PlanarConfiguration, 1) \n", + "\n", + " t.set(Tag.TileWidth, width)\n", + " t.set(Tag.TileLength, height)\n", + " t.set(Tag.Orientation, rh.full_metadata['Image 
Orientation'].values[0])\n", + " t.set(Tag.PhotometricInterpretation, PhotometricInterpretation.Color_Filter_Array)\n", + " t.set(Tag.CFARepeatPatternDim, [2,2])\n", + " t.set(Tag.CFAPattern, CFAPattern.RGGB)\n", + " bl = rh.core_metadata.black_level_per_channel\n", + " t.set(Tag.BlackLevelRepeatDim, [2,2])\n", + " t.set(Tag.BlackLevel, bl)\n", + " t.set(Tag.WhiteLevel, rh.core_metadata.white_level)\n", + "\n", + " t.set(Tag.BitsPerSample, bpp)\n", + "\n", + " t.set(Tag.ColorMatrix1, ccm1)\n", + " t.set(Tag.CalibrationIlluminant1, CalibrationIlluminant.D65)\n", + " wb = get_as_shot_neutral(rh)\n", + " t.set(Tag.AsShotNeutral, wb)\n", + " t.set(Tag.BaselineExposure, [[0,100]])\n", + " t.set(Tag.Make, rh.full_metadata['Image Make'].values)\n", + " t.set(Tag.Model, rh.full_metadata['Image Model'].values)\n", + "\n", + "\n", + "\n", + " t.set(Tag.FocalLength, FocalLength)\n", + " t.set(Tag.EXIFPhotoLensModel, rh.full_metadata['EXIF LensModel'].values)\n", + " t.set(Tag.ExposureBiasValue, ExposureBiasValue)\n", + " t.set(Tag.ExposureTime, exposures)\n", + " t.set(Tag.FNumber, fnumber)\n", + " t.set(Tag.PhotographicSensitivity, rh.full_metadata['EXIF ISOSpeedRatings'].values)\n", + " t.set(Tag.DNGVersion, DNGVersion.V1_4)\n", + " t.set(Tag.DNGBackwardVersion, DNGVersion.V1_2)\n", + " t.set(Tag.PreviewColorSpace, PreviewColorSpace.Adobe_RGB)\n", + "\n", + " r = RAW2DNG()\n", + "\n", + " r.options(t, path=\"\", compress=False)\n", + "\n", + " r.convert(uint_img, filename=filepath)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "07151d60", + "metadata": {}, + "outputs": [], + "source": [ + "for key in tqdm(pair_file_list.keys()):\n", + " image_pairs = pair_file_list[key]\n", + " for noisy, gt in image_pairs:\n", + " try:\n", + " bayer, rh = get_file(f'{raw_path}/{noisy}')\n", + " noisy_path = outpath / (noisy)\n", + " to_dng(bayer, rh, str(noisy_path))\n", + "\n", + " gt_path = outpath / (gt)\n", + " if not os.path.exists(gt_path):\n", + " bayer, rh 
= get_file(f'{raw_path}/{gt}')\n", + " to_dng(bayer, rh, str(gt_path))\n", + " except:\n", + " print(f\"Skipping {raw_path}/{noisy}, {raw_path}/{gt}\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "06af6496", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6dcdae9b", + "metadata": {}, + "outputs": [], + "source": [] + }, { "cell_type": "code", "execution_count": null, @@ -147,7 +334,7 @@ "\n", " # Check if the image is large enough to be cropped.\n", " if width < crop_size or height < crop_size:\n", - " im = rh.apply_colorspace_transform( colorspace=colorspace)\n", + " im = rh.apply_colorspace_transform(colorspace=colorspace)\n", " else:\n", " # Calculate the coordinates for the center crop.\n", " left = (width - crop_size) // 2\n", @@ -170,7 +357,7 @@ " # im_clipped = np.clip(im_scaled, 0.0, 65535.0)\n", "\n", " # im_uint16 = im_clipped.astype(np.uint16)\n", - " return im" + " return im, rh.raw[left:right, top:bottom], rh" ] }, { @@ -203,8 +390,474 @@ " gt_bayer = get_file(f'{raw_path}/{gt}')\n", " gt_bayer.tofile(gt_path)\n", " except:\n", - " print(f\"Skipping {raw_path}/{noisy}, {raw_path}/{gt}\")" + " print(f\"Skipping {raw_path}/{noisy}, {raw_path}/{gt}\")\n", + " break\n", + " break" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "264b6831", + "metadata": {}, + "outputs": [], + "source": [ + "dir(Tag)" ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "225e097f", + "metadata": {}, + "outputs": [], + "source": [ + "from pidng.core import RAW2DNG, DNGTags, Tag\n", + "from pidng.defs import *\n", + "\n", + "def get_ratios(string, rh):\n", + " return [x.as_integer_ratio() for x in rh.full_metadata[string].values]\n", + "\n", + "\n", + "def rational_wb(rh, denominator=1000):\n", + " wb = np.array(rh.core_metadata.camera_white_balance)\n", + " numerator_matrix = np.round(wb * denominator).astype(int)\n", + " return [[num, 
denominator] for num in numerator_matrix]\n", + "def convert_ccm_to_rational(matrix_3x3, denominator=10000):\n", + "\n", + " numerator_matrix = np.round(matrix_3x3 * denominator).astype(int)\n", + " numerators_flat = numerator_matrix.flatten()\n", + " ccm_rational = [[num, denominator] for num in numerators_flat]\n", + " \n", + " return ccm_rational\n", + "\n", + "\n", + "def get_as_shot_neutral(rh, denominator=10000):\n", + "\n", + " # Get multipliers [R, G1, B, G2]\n", + " cam_mul = rh.core_metadata.camera_white_balance\n", + " \n", + " # Check for zero multipliers to avoid division by zero\n", + " if cam_mul[0] == 0 or cam_mul[2] == 0:\n", + " # Fallback to [1, 1, 1] if multipliers are bad\n", + " return [[denominator, denominator], [denominator, denominator], [denominator, denominator]]\n", + "\n", + " # Calculate inverse multipliers normalized to G (cam_mul[1])\n", + " # DNG spec AsShotNeutral = [1/R_scale, 1/G_scale, 1/B_scale]\n", + " # where G_scale = 1.0. This means:\n", + " # R_scale = R_mult / G_mult\n", + " # B_scale = B_mult / G_mult\n", + " # So 1/R_scale = G_mult / R_mult\n", + " \n", + " r_neutral = cam_mul[1] / cam_mul[0]\n", + " g_neutral = 1.0 # G is always 1.0\n", + " b_neutral = cam_mul[1] / cam_mul[2]\n", + "\n", + " return [\n", + " [int(r_neutral * denominator), denominator],\n", + " [int(g_neutral * denominator), denominator],\n", + " [int(b_neutral * denominator), denominator],\n", + " ]\n", + "\n", + "\n", + "def to_dng(uint_img, rh, filepath):\n", + " uint_img = np.ascontiguousarray(uint_img)\n", + " width = uint_img.shape[1]\n", + " height = uint_img.shape[0]\n", + " bpp = 16\n", + "\n", + " exposures = get_ratios('EXIF ExposureTime', rh)\n", + " fnumber = get_ratios('EXIF FNumber', rh)\n", + " ExposureBiasValue = get_ratios('EXIF ExposureBiasValue', rh) \n", + " FocalLength = get_ratios('EXIF FocalLength', rh) \n", + " ccm1 = convert_ccm_to_rational(rh.core_metadata.rgb_xyz_matrix[:3, :])\n", + " t = DNGTags()\n", + " 
t.set(Tag.ImageWidth, width)\n", + " t.set(Tag.ImageLength, height)\n", + " t.set(Tag.TileWidth, width)\n", + " t.set(Tag.TileLength, height)\n", + " t.set(Tag.BitsPerSample, bpp)\n", + "\n", + " t.set(Tag.SamplesPerPixel, 1) \n", + " t.set(Tag.PlanarConfiguration, 1) \n", + "\n", + " t.set(Tag.TileWidth, width)\n", + " t.set(Tag.TileLength, height)\n", + " t.set(Tag.Orientation, rh.full_metadata['Image Orientation'].values[0])\n", + " t.set(Tag.PhotometricInterpretation, PhotometricInterpretation.Color_Filter_Array)\n", + " t.set(Tag.CFARepeatPatternDim, [2,2])\n", + " t.set(Tag.CFAPattern, CFAPattern.RGGB)\n", + " bl = rh.core_metadata.black_level_per_channel\n", + " t.set(Tag.BlackLevelRepeatDim, [2,2])\n", + " t.set(Tag.BlackLevel, bl)\n", + " t.set(Tag.WhiteLevel, rh.core_metadata.white_level)\n", + "\n", + " t.set(Tag.BitsPerSample, bpp)\n", + "\n", + " t.set(Tag.ColorMatrix1, ccm1)\n", + " t.set(Tag.CalibrationIlluminant1, CalibrationIlluminant.D65)\n", + " wb = get_as_shot_neutral(rh)\n", + " print(wb)\n", + " t.set(Tag.AsShotNeutral, wb)\n", + " t.set(Tag.BaselineExposure, [[0,100]])\n", + " t.set(Tag.Make, rh.full_metadata['Image Make'].values)\n", + " t.set(Tag.Model, rh.full_metadata['Image Model'].values)\n", + "\n", + "\n", + "\n", + " t.set(Tag.FocalLength, FocalLength)\n", + " t.set(Tag.EXIFPhotoLensModel, rh.full_metadata['EXIF LensModel'].values)\n", + " t.set(Tag.ExposureBiasValue, ExposureBiasValue)\n", + " t.set(Tag.ExposureTime, exposures)\n", + " t.set(Tag.FNumber, fnumber)\n", + " t.set(Tag.PhotographicSensitivity, rh.full_metadata['EXIF ISOSpeedRatings'].values)\n", + " t.set(Tag.DNGVersion, DNGVersion.V1_4)\n", + " t.set(Tag.DNGBackwardVersion, DNGVersion.V1_2)\n", + " t.set(Tag.PreviewColorSpace, PreviewColorSpace.Adobe_RGB)\n", + "\n", + " r = RAW2DNG()\n", + "\n", + " r.options(t, path=\"\", compress=False)\n", + "\n", + " r.convert(uint_img, filename=filepath)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": 
"26d424c6", + "metadata": {}, + "outputs": [], + "source": [ + "Tag.BlackLevel" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "072cf0d2", + "metadata": {}, + "outputs": [], + "source": [ + "im, raw, rh = get_file(f'{raw_path}/{noisy}')\n", + "width, height = rh.raw.shape\n", + "\n", + "# Check if the image is large enough to be cropped.\n", + "if width < crop_size or height < crop_size:\n", + " im = rh.apply_colorspace_transform( colorspace=colorspace)\n", + "else:\n", + " # Calculate the coordinates for the center crop.\n", + " left = (width - crop_size) // 2\n", + " top = (height - crop_size) // 2\n", + "\n", + " # Ensure the top-left corner is on an even pixel coordinate for Bayer alignment.\n", + " if left % 2 != 0:\n", + " left -= 1\n", + " if top % 2 != 0:\n", + " top -= 1\n", + " \n", + " # Calculate the bottom-right corner based on the adjusted top-left corner.\n", + " # Since crop_size is even, right and bottom will also be even.\n", + " right = left + crop_size\n", + " bottom = top + crop_size" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d9010156", + "metadata": {}, + "outputs": [], + "source": [ + "rh.raw.max()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "28809a03", + "metadata": {}, + "outputs": [], + "source": [ + "im = rh.as_rgb(dims=(left, right, top, bottom))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5137cf23", + "metadata": {}, + "outputs": [], + "source": [ + "rh.core_metadata.raw_pattern" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "562b2efd", + "metadata": {}, + "outputs": [], + "source": [ + "plt.imshow(im.transpose(1, 2, 0)**(1/2.2))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5047b5e9", + "metadata": {}, + "outputs": [], + "source": [ + "rh.raw[left:right, top:bottom]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4faf9bf4", + "metadata": {}, + 
"outputs": [], + "source": [ + "rhdng.raw" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "66de04c9", + "metadata": {}, + "outputs": [], + "source": [ + "rh.core_metadata" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a1c1bc06", + "metadata": {}, + "outputs": [], + "source": [ + "to_dng(raw, rh, \"test_camera_WB_D65_compress\")\n", + "rhdng = RawHandler(\"test_camera_WB_D65_compress.dng\")\n", + "rhdng.core_metadata" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6509b267", + "metadata": {}, + "outputs": [], + "source": [ + "rhdng.raw.max()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5831855f", + "metadata": {}, + "outputs": [], + "source": [ + "rh.raw[left:right, top:bottom].max()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9c9fe790", + "metadata": {}, + "outputs": [], + "source": [ + "2.80898881*1024" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b7943162", + "metadata": {}, + "outputs": [], + "source": [ + "1.80864525*1024" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a4a368fa", + "metadata": {}, + "outputs": [], + "source": [ + "im2 = rhdng.as_rgb()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a4c73bb2", + "metadata": {}, + "outputs": [], + "source": [ + "plt.imshow(im2.transpose(1, 2, 0)**(1/2.2))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1f25ed2f", + "metadata": {}, + "outputs": [], + "source": [ + "from src.training.load_config import load_config\n", + "from src.training.SmallRawDatasetNumpy import SmallRawDatasetNumpy\n", + "from src.training.censored_fit import censored_linear_fit_twosided\n", + "\n", + "run_config = load_config()\n", + "dataset_path = Path(run_config['cropped_raw_subdir'])\n", + "align_csv = dataset_path / run_config['secondary_align_csv']\n", + "crop_size = 1500" + ] + }, + { + "cell_type": 
"code", + "execution_count": null, + "id": "1c86b449", + "metadata": {}, + "outputs": [], + "source": [ + "dataset = SmallRawDatasetNumpy(dataset_path, align_csv, crop_size=crop_size)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a1e363d7", + "metadata": {}, + "outputs": [], + "source": [ + "output = dataset[0]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b2f39a99", + "metadata": {}, + "outputs": [], + "source": [ + "plt.imshow(output['aligned'].permute(1, 2, 0))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8ad377b4", + "metadata": {}, + "outputs": [], + "source": [ + "rhdng = RawHandler(\"test_camera_WB_D65_compress.dng\")\n", + "rh = RawHandler('/Volumes/EasyStore/RAWNIND/Bayer_MuseeL-sol-A7C-brighter_ISO100_sha1=18eaa9931d9a0f6f0511552ef6bf2fd040d82878.arw')\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3269995e", + "metadata": {}, + "outputs": [], + "source": [ + "width, height = rh.raw.shape\n", + "\n", + "# Check if the image is large enough to be cropped.\n", + "if width < crop_size or height < crop_size:\n", + " im = rh.apply_colorspace_transform( colorspace=colorspace)\n", + "else:\n", + " # Calculate the coordinates for the center crop.\n", + " left = (width - crop_size) // 2\n", + " top = (height - crop_size) // 2\n", + "\n", + " # Ensure the top-left corner is on an even pixel coordinate for Bayer alignment.\n", + " if left % 2 != 0:\n", + " left -= 1\n", + " if top % 2 != 0:\n", + " top -= 1\n", + " \n", + " # Calculate the bottom-right corner based on the adjusted top-left corner.\n", + " # Since crop_size is even, right and bottom will also be even.\n", + " right = left + crop_size\n", + " bottom = top + crop_size" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "75ded325", + "metadata": {}, + "outputs": [], + "source": [ + "rh.raw[left:right, top:bottom]" + ] + }, + { + "cell_type": "code", + "execution_count": 
null, + "id": "cb7f6deb", + "metadata": {}, + "outputs": [], + "source": [ + "rhdng.raw" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d95fa141", + "metadata": {}, + "outputs": [], + "source": [ + "imdng = rhdng.as_rgb()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "45796c08", + "metadata": {}, + "outputs": [], + "source": [ + "im = rh.as_rgb(dims=(left, right, top, bottom))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "406c5201", + "metadata": {}, + "outputs": [], + "source": [ + "(imdng-im).mean()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4fe26aa9", + "metadata": {}, + "outputs": [], + "source": [ + "im" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "45323089", + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { From 08472057caed9f9d135e3e5b8ab26abe6aa936e4 Mon Sep 17 00:00:00 2001 From: Ryan Mueller Date: Wed, 29 Oct 2025 18:57:07 -0400 Subject: [PATCH 33/56] Cleaned up dng making script --- 0_produce_small_dataset_raw.ipynb | 523 +----------------------------- 1 file changed, 16 insertions(+), 507 deletions(-) diff --git a/0_produce_small_dataset_raw.ipynb b/0_produce_small_dataset_raw.ipynb index 0637156..8d818c2 100644 --- a/0_produce_small_dataset_raw.ipynb +++ b/0_produce_small_dataset_raw.ipynb @@ -307,448 +307,11 @@ { "cell_type": "code", "execution_count": null, - "id": "06af6496", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "6dcdae9b", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "1213a101", - "metadata": {}, - "outputs": [], - "source": [ - "def get_file(impath, crop_size=crop_size):\n", - " rh = RawHandler(impath)\n", - " \n", - " width, height = rh.raw.shape\n", - "\n", - " # Check if the image is large enough to be cropped.\n", - " if width < crop_size or 
height < crop_size:\n", - " im = rh.apply_colorspace_transform(colorspace=colorspace)\n", - " else:\n", - " # Calculate the coordinates for the center crop.\n", - " left = (width - crop_size) // 2\n", - " top = (height - crop_size) // 2\n", - "\n", - " # Ensure the top-left corner is on an even pixel coordinate for Bayer alignment.\n", - " if left % 2 != 0:\n", - " left -= 1\n", - " if top % 2 != 0:\n", - " top -= 1\n", - " \n", - " # Calculate the bottom-right corner based on the adjusted top-left corner.\n", - " # Since crop_size is even, right and bottom will also be even.\n", - " right = left + crop_size\n", - " bottom = top + crop_size\n", - "\n", - " im = rh.apply_colorspace_transform(dims=(left, right, top, bottom), colorspace=colorspace)\n", - " im = im.astype(np.float16)\n", - " # im_scaled = im * 65535.0\n", - " # im_clipped = np.clip(im_scaled, 0.0, 65535.0)\n", - "\n", - " # im_uint16 = im_clipped.astype(np.uint16)\n", - " return im, rh.raw[left:right, top:bottom], rh" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "010774f5", - "metadata": {}, - "outputs": [], - "source": [ - "from tqdm import tqdm" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "5c0c0816", - "metadata": {}, - "outputs": [], - "source": [ - "for key in tqdm(pair_file_list.keys()):\n", - " image_pairs = pair_file_list[key]\n", - " for noisy, gt in image_pairs:\n", - " try:\n", - " noisy_bayer = get_file(f'{raw_path}/{noisy}')\n", - " noisy_path = outpath / (noisy + \".f16.raw\")\n", - " noisy_bayer.tofile(noisy_path)\n", - "\n", - " gt_path = outpath / (gt + \".f16.raw\")\n", - " if not os.path.exists(gt_path):\n", - " gt_bayer = get_file(f'{raw_path}/{gt}')\n", - " gt_bayer.tofile(gt_path)\n", - " except:\n", - " print(f\"Skipping {raw_path}/{noisy}, {raw_path}/{gt}\")\n", - " break\n", - " break" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "264b6831", - "metadata": {}, - "outputs": [], - "source": [ - 
"dir(Tag)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "225e097f", - "metadata": {}, - "outputs": [], - "source": [ - "from pidng.core import RAW2DNG, DNGTags, Tag\n", - "from pidng.defs import *\n", - "\n", - "def get_ratios(string, rh):\n", - " return [x.as_integer_ratio() for x in rh.full_metadata[string].values]\n", - "\n", - "\n", - "def rational_wb(rh, denominator=1000):\n", - " wb = np.array(rh.core_metadata.camera_white_balance)\n", - " numerator_matrix = np.round(wb * denominator).astype(int)\n", - " return [[num, denominator] for num in numerator_matrix]\n", - "def convert_ccm_to_rational(matrix_3x3, denominator=10000):\n", - "\n", - " numerator_matrix = np.round(matrix_3x3 * denominator).astype(int)\n", - " numerators_flat = numerator_matrix.flatten()\n", - " ccm_rational = [[num, denominator] for num in numerators_flat]\n", - " \n", - " return ccm_rational\n", - "\n", - "\n", - "def get_as_shot_neutral(rh, denominator=10000):\n", - "\n", - " # Get multipliers [R, G1, B, G2]\n", - " cam_mul = rh.core_metadata.camera_white_balance\n", - " \n", - " # Check for zero multipliers to avoid division by zero\n", - " if cam_mul[0] == 0 or cam_mul[2] == 0:\n", - " # Fallback to [1, 1, 1] if multipliers are bad\n", - " return [[denominator, denominator], [denominator, denominator], [denominator, denominator]]\n", - "\n", - " # Calculate inverse multipliers normalized to G (cam_mul[1])\n", - " # DNG spec AsShotNeutral = [1/R_scale, 1/G_scale, 1/B_scale]\n", - " # where G_scale = 1.0. 
This means:\n", - " # R_scale = R_mult / G_mult\n", - " # B_scale = B_mult / G_mult\n", - " # So 1/R_scale = G_mult / R_mult\n", - " \n", - " r_neutral = cam_mul[1] / cam_mul[0]\n", - " g_neutral = 1.0 # G is always 1.0\n", - " b_neutral = cam_mul[1] / cam_mul[2]\n", - "\n", - " return [\n", - " [int(r_neutral * denominator), denominator],\n", - " [int(g_neutral * denominator), denominator],\n", - " [int(b_neutral * denominator), denominator],\n", - " ]\n", - "\n", - "\n", - "def to_dng(uint_img, rh, filepath):\n", - " uint_img = np.ascontiguousarray(uint_img)\n", - " width = uint_img.shape[1]\n", - " height = uint_img.shape[0]\n", - " bpp = 16\n", - "\n", - " exposures = get_ratios('EXIF ExposureTime', rh)\n", - " fnumber = get_ratios('EXIF FNumber', rh)\n", - " ExposureBiasValue = get_ratios('EXIF ExposureBiasValue', rh) \n", - " FocalLength = get_ratios('EXIF FocalLength', rh) \n", - " ccm1 = convert_ccm_to_rational(rh.core_metadata.rgb_xyz_matrix[:3, :])\n", - " t = DNGTags()\n", - " t.set(Tag.ImageWidth, width)\n", - " t.set(Tag.ImageLength, height)\n", - " t.set(Tag.TileWidth, width)\n", - " t.set(Tag.TileLength, height)\n", - " t.set(Tag.BitsPerSample, bpp)\n", - "\n", - " t.set(Tag.SamplesPerPixel, 1) \n", - " t.set(Tag.PlanarConfiguration, 1) \n", - "\n", - " t.set(Tag.TileWidth, width)\n", - " t.set(Tag.TileLength, height)\n", - " t.set(Tag.Orientation, rh.full_metadata['Image Orientation'].values[0])\n", - " t.set(Tag.PhotometricInterpretation, PhotometricInterpretation.Color_Filter_Array)\n", - " t.set(Tag.CFARepeatPatternDim, [2,2])\n", - " t.set(Tag.CFAPattern, CFAPattern.RGGB)\n", - " bl = rh.core_metadata.black_level_per_channel\n", - " t.set(Tag.BlackLevelRepeatDim, [2,2])\n", - " t.set(Tag.BlackLevel, bl)\n", - " t.set(Tag.WhiteLevel, rh.core_metadata.white_level)\n", - "\n", - " t.set(Tag.BitsPerSample, bpp)\n", - "\n", - " t.set(Tag.ColorMatrix1, ccm1)\n", - " t.set(Tag.CalibrationIlluminant1, CalibrationIlluminant.D65)\n", - " wb = 
get_as_shot_neutral(rh)\n", - " print(wb)\n", - " t.set(Tag.AsShotNeutral, wb)\n", - " t.set(Tag.BaselineExposure, [[0,100]])\n", - " t.set(Tag.Make, rh.full_metadata['Image Make'].values)\n", - " t.set(Tag.Model, rh.full_metadata['Image Model'].values)\n", - "\n", - "\n", - "\n", - " t.set(Tag.FocalLength, FocalLength)\n", - " t.set(Tag.EXIFPhotoLensModel, rh.full_metadata['EXIF LensModel'].values)\n", - " t.set(Tag.ExposureBiasValue, ExposureBiasValue)\n", - " t.set(Tag.ExposureTime, exposures)\n", - " t.set(Tag.FNumber, fnumber)\n", - " t.set(Tag.PhotographicSensitivity, rh.full_metadata['EXIF ISOSpeedRatings'].values)\n", - " t.set(Tag.DNGVersion, DNGVersion.V1_4)\n", - " t.set(Tag.DNGBackwardVersion, DNGVersion.V1_2)\n", - " t.set(Tag.PreviewColorSpace, PreviewColorSpace.Adobe_RGB)\n", - "\n", - " r = RAW2DNG()\n", - "\n", - " r.options(t, path=\"\", compress=False)\n", - "\n", - " r.convert(uint_img, filename=filepath)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "26d424c6", - "metadata": {}, - "outputs": [], - "source": [ - "Tag.BlackLevel" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "072cf0d2", - "metadata": {}, - "outputs": [], - "source": [ - "im, raw, rh = get_file(f'{raw_path}/{noisy}')\n", - "width, height = rh.raw.shape\n", - "\n", - "# Check if the image is large enough to be cropped.\n", - "if width < crop_size or height < crop_size:\n", - " im = rh.apply_colorspace_transform( colorspace=colorspace)\n", - "else:\n", - " # Calculate the coordinates for the center crop.\n", - " left = (width - crop_size) // 2\n", - " top = (height - crop_size) // 2\n", - "\n", - " # Ensure the top-left corner is on an even pixel coordinate for Bayer alignment.\n", - " if left % 2 != 0:\n", - " left -= 1\n", - " if top % 2 != 0:\n", - " top -= 1\n", - " \n", - " # Calculate the bottom-right corner based on the adjusted top-left corner.\n", - " # Since crop_size is even, right and bottom will also be even.\n", - " 
right = left + crop_size\n", - " bottom = top + crop_size" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "d9010156", + "id": "9414c7a1", "metadata": {}, "outputs": [], "source": [ - "rh.raw.max()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "28809a03", - "metadata": {}, - "outputs": [], - "source": [ - "im = rh.as_rgb(dims=(left, right, top, bottom))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "5137cf23", - "metadata": {}, - "outputs": [], - "source": [ - "rh.core_metadata.raw_pattern" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "562b2efd", - "metadata": {}, - "outputs": [], - "source": [ - "plt.imshow(im.transpose(1, 2, 0)**(1/2.2))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "5047b5e9", - "metadata": {}, - "outputs": [], - "source": [ - "rh.raw[left:right, top:bottom]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "4faf9bf4", - "metadata": {}, - "outputs": [], - "source": [ - "rhdng.raw" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "66de04c9", - "metadata": {}, - "outputs": [], - "source": [ - "rh.core_metadata" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "a1c1bc06", - "metadata": {}, - "outputs": [], - "source": [ - "to_dng(raw, rh, \"test_camera_WB_D65_compress\")\n", - "rhdng = RawHandler(\"test_camera_WB_D65_compress.dng\")\n", - "rhdng.core_metadata" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "6509b267", - "metadata": {}, - "outputs": [], - "source": [ - "rhdng.raw.max()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "5831855f", - "metadata": {}, - "outputs": [], - "source": [ - "rh.raw[left:right, top:bottom].max()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "9c9fe790", - "metadata": {}, - "outputs": [], - "source": [ - "2.80898881*1024" - ] - }, - { - "cell_type": 
"code", - "execution_count": null, - "id": "b7943162", - "metadata": {}, - "outputs": [], - "source": [ - "1.80864525*1024" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "a4a368fa", - "metadata": {}, - "outputs": [], - "source": [ - "im2 = rhdng.as_rgb()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "a4c73bb2", - "metadata": {}, - "outputs": [], - "source": [ - "plt.imshow(im2.transpose(1, 2, 0)**(1/2.2))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "1f25ed2f", - "metadata": {}, - "outputs": [], - "source": [ - "from src.training.load_config import load_config\n", - "from src.training.SmallRawDatasetNumpy import SmallRawDatasetNumpy\n", - "from src.training.censored_fit import censored_linear_fit_twosided\n", - "\n", - "run_config = load_config()\n", - "dataset_path = Path(run_config['cropped_raw_subdir'])\n", - "align_csv = dataset_path / run_config['secondary_align_csv']\n", - "crop_size = 1500" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "1c86b449", - "metadata": {}, - "outputs": [], - "source": [ - "dataset = SmallRawDatasetNumpy(dataset_path, align_csv, crop_size=crop_size)\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "a1e363d7", - "metadata": {}, - "outputs": [], - "source": [ - "output = dataset[0]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "b2f39a99", - "metadata": {}, - "outputs": [], - "source": [ - "plt.imshow(output['aligned'].permute(1, 2, 0))" + "#Testing data is properly copied" ] }, { @@ -758,59 +321,10 @@ "metadata": {}, "outputs": [], "source": [ - "rhdng = RawHandler(\"test_camera_WB_D65_compress.dng\")\n", + "rhdng = RawHandler(str(outpath / \"Bayer_MuseeL-sol-A7C-brighter_ISO100_sha1=18eaa9931d9a0f6f0511552ef6bf2fd040d82878.arw.dng\"))\n", "rh = RawHandler('/Volumes/EasyStore/RAWNIND/Bayer_MuseeL-sol-A7C-brighter_ISO100_sha1=18eaa9931d9a0f6f0511552ef6bf2fd040d82878.arw')\n" ] }, - { - 
"cell_type": "code", - "execution_count": null, - "id": "3269995e", - "metadata": {}, - "outputs": [], - "source": [ - "width, height = rh.raw.shape\n", - "\n", - "# Check if the image is large enough to be cropped.\n", - "if width < crop_size or height < crop_size:\n", - " im = rh.apply_colorspace_transform( colorspace=colorspace)\n", - "else:\n", - " # Calculate the coordinates for the center crop.\n", - " left = (width - crop_size) // 2\n", - " top = (height - crop_size) // 2\n", - "\n", - " # Ensure the top-left corner is on an even pixel coordinate for Bayer alignment.\n", - " if left % 2 != 0:\n", - " left -= 1\n", - " if top % 2 != 0:\n", - " top -= 1\n", - " \n", - " # Calculate the bottom-right corner based on the adjusted top-left corner.\n", - " # Since crop_size is even, right and bottom will also be even.\n", - " right = left + crop_size\n", - " bottom = top + crop_size" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "75ded325", - "metadata": {}, - "outputs": [], - "source": [ - "rh.raw[left:right, top:bottom]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "cb7f6deb", - "metadata": {}, - "outputs": [], - "source": [ - "rhdng.raw" - ] - }, { "cell_type": "code", "execution_count": null, @@ -828,6 +342,19 @@ "metadata": {}, "outputs": [], "source": [ + "width, height = rh.raw.shape\n", + "\n", + "left = (width - crop_size) // 2\n", + "top = (height - crop_size) // 2\n", + "\n", + "if left % 2 != 0:\n", + " left -= 1\n", + "if top % 2 != 0:\n", + " top -= 1\n", + "\n", + "right = left + crop_size\n", + "bottom = top + crop_size\n", + "\n", "im = rh.as_rgb(dims=(left, right, top, bottom))" ] }, @@ -840,24 +367,6 @@ "source": [ "(imdng-im).mean()" ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "4fe26aa9", - "metadata": {}, - "outputs": [], - "source": [ - "im" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "45323089", - "metadata": {}, - "outputs": [], - "source": 
[] } ], "metadata": { From 2c508a98fc9b2e78a0cfbdcbc4ed7f729738a5f9 Mon Sep 17 00:00:00 2001 From: Ryan Mueller Date: Thu, 30 Oct 2025 15:07:44 -0400 Subject: [PATCH 34/56] Updated censored fit --- src/training/censored_fit.py | 36 ++++++++++++++++-------------------- 1 file changed, 16 insertions(+), 20 deletions(-) diff --git a/src/training/censored_fit.py b/src/training/censored_fit.py index 6b41e4d..f5f49b9 100644 --- a/src/training/censored_fit.py +++ b/src/training/censored_fit.py @@ -1,7 +1,8 @@ import numpy as np from scipy.stats import norm -def censored_linear_fit_twosided(x, y, clip_low=None, clip_high=None, max_iter=200, tol=1e-6, include_offset=True): +def censored_linear_fit_twosided(x, y, clip_low=None, clip_high=None, + max_iter=200, tol=1e-6, include_offset=True): """ Fit y ≈ a + b*x + ε, ε ~ N(0, σ²) under two-sided censoring: clip_low ≤ y_true ≤ clip_high @@ -18,33 +19,29 @@ def censored_linear_fit_twosided(x, y, clip_low=None, clip_high=None, max_iter=2 Maximum EM iterations. tol : float Relative tolerance for convergence. - include_offset: bool - Compute linear fit with offset (y = b * x + a) + include_offset : bool + If False, forces intercept a=0 (fit y ≈ b*x). + Returns ------- a, b, sigma : floats Estimated regression parameters. 
""" - x = np.asarray(x).ravel() y = np.asarray(y).ravel() mask = np.isfinite(x) & np.isfinite(y) x, y = x[mask], y[mask] - n = len(x) - if n < 3: + if len(x) < 3: raise ValueError("Not enough data points.") # --- initial guess (ordinary least squares) --- if include_offset: A = np.vstack([np.ones_like(x), x]).T + a, b = np.linalg.lstsq(A, y, rcond=None)[0] else: - A = np.vstack([x]).T - coef, *_ = np.linalg.lstsq(A, y, rcond=None) - if include_offset: - a, b = coef - else: - a, b = 0, coef[0] + b = np.dot(x, y) / np.dot(x, x) + a = 0.0 sigma = np.std(y - (a + b*x)) for _ in range(max_iter): @@ -79,18 +76,17 @@ def censored_linear_fit_twosided(x, y, clip_low=None, clip_high=None, max_iter=2 # M-step: re-fit with imputed expectations if include_offset: A = np.vstack([np.ones_like(x), x]).T + a_new, b_new = np.linalg.lstsq(A, y_exp, rcond=None)[0] else: - A = np.vstack([x]).T - coef, *_ = np.linalg.lstsq(A, y, rcond=None) - if include_offset: - a_new, b_new = coef - else: - a_new, b_new = 0, coef[0] + b_new = np.dot(x, y_exp) / np.dot(x, x) + a_new = 0.0 + sigma_new = np.std(y_exp - (a_new + b_new*x)) - if np.allclose([a, b, sigma], [a_new, b_new, sigma_new], rtol=tol, atol=tol): + if np.allclose([a, b, sigma], [a_new, b_new, sigma_new], + rtol=tol, atol=tol): break + a, b, sigma = a_new, b_new, sigma_new return a, b, sigma - From ea59c0c631fe8c9aba2cb6fb77247dbcaf1b520c Mon Sep 17 00:00:00 2001 From: Ryan Mueller Date: Thu, 30 Oct 2025 15:08:19 -0400 Subject: [PATCH 35/56] Dataset for raw images --- src/training/RawDatasetDNG.py | 124 ++++++++++++++++++++++++++++++++++ 1 file changed, 124 insertions(+) create mode 100644 src/training/RawDatasetDNG.py diff --git a/src/training/RawDatasetDNG.py b/src/training/RawDatasetDNG.py new file mode 100644 index 0000000..103d40f --- /dev/null +++ b/src/training/RawDatasetDNG.py @@ -0,0 +1,124 @@ +import pandas as pd +import os +from torch.utils.data import Dataset +import imageio +from colour_demosaicing import ( + 
ROOT_RESOURCES_EXAMPLES, + demosaicing_CFA_Bayer_bilinear, + demosaicing_CFA_Bayer_Malvar2004, + demosaicing_CFA_Bayer_Menon2007, + mosaicing_CFA_Bayer) + +from src.training.utils import inverse_gamma_tone_curve, cfa_to_sparse +import numpy as np +import torch +from src.training.align_images import apply_alignment, align_clean_to_noisy +from pathlib import Path +from RawHandler.RawHandler import RawHandler + + + +def global_affine_match(A, D, mask=None): + """ + Fit D ≈ a + b*A with least squares. + A, D : 2D arrays, same shape (linear values) + mask : optional boolean array, True=use pixel + returns: a, b, D_pred, D_resid (D - (a + b*A)) + """ + A = A.ravel().astype(np.float64) + D = D.ravel().astype(np.float64) + if mask is None: + mask = np.isfinite(A) & np.isfinite(D) + else: + mask = mask.ravel() & np.isfinite(A) & np.isfinite(D) + + A0 = A[mask] + D0 = D[mask] + # design matrix [1, A] + X = np.vstack([np.ones_like(A0), A0]).T + coef, *_ = np.linalg.lstsq(X, D0, rcond=None) + a, b = coef[0], coef[1] + D_pred = (a + b * A).reshape(-1) + D_pred = D_pred.reshape(A.shape) if False else (a + b * A).reshape((-1,)) # keep flatten + + return a, b, (a + b * A) + + +def random_crop_dim(shape, crop_size, buffer, validation=False): + h, w = shape + if not validation: + top = np.random.randint(0 + buffer, h - crop_size - buffer) + left = np.random.randint(0 + buffer, w - crop_size - buffer) + else: + top = (h - crop_size) // 2 + left = (w - crop_size) // 2 + + if top % 2 != 0: top = top - 1 + if left % 2 != 0: left = left - 1 + bottom = top + crop_size + right = left + crop_size + return (left, right, top, bottom) + +class RawDatasetDNG(Dataset): + def __init__(self, path, csv, colorspace, crop_size=180, buffer=10, validation=False, run_align=False, dimensions=2000): + super().__init__() + self.df = pd.read_csv(csv) + self.path = path + self.crop_size = crop_size + self.buffer = buffer + self.coordinate_iso = 6400 + self.validation=validation + self.run_align = run_align + 
self.dtype = np.float16 + self.dimensions = dimensions + self.colorspace = colorspace + + def __len__(self): + return len(self.df) + + def __getitem__(self, idx): + row = self.df.iloc[idx] + # Load images + try: + name = Path(f"{row.bayer_path}").name + name = str(self.path / name.replace('_bayer.jpg', '.dng')) + noisy_rh = RawHandler(name) + except: + print(name) + + try: + gt_name = Path(f"{row.gt_path}").name + gt_name = str(self.path / gt_name.replace('.jpg', '.dng')) + gt_rh = RawHandler(gt_name) + except: + print(gt_name) + + dims = random_crop_dim(noisy_rh.raw.shape, self.crop_size, self.buffer, validation=self.validation) + try: + bayer_data = noisy_rh.apply_colorspace_transform(dims=dims, colorspace=self.colorspace) + noisy = noisy_rh.as_rgb(dims=dims, colorspace=self.colorspace) + rggb = noisy_rh.as_rggb(dims=dims, colorspace=self.colorspace) + sparse = noisy_rh.as_sparse(dims=dims, colorspace=self.colorspace) + + expanded_dims = [dims[0]-self.buffer, dims[1]+self.buffer, dims[2]-self.buffer, dims[3]+self.buffer] + gt_expanded = gt_rh.as_rgb(dims=expanded_dims, colorspace=self.colorspace) + aligned = apply_alignment(gt_expanded.transpose(1, 2, 0), row.to_dict())[self.buffer:-self.buffer, self.buffer:-self.buffer] + gt_non_aligned = gt_expanded.transpose(1, 2, 0)[self.buffer:-self.buffer, self.buffer:-self.buffer] + except: + print(name, gt_name) + + # aligned = gt_rh.as_rgb(dims=dims, colorspace=self.colorspace).transpose(1, 2, 0) + + # Convert to tensors + output = { + "bayer": torch.tensor(bayer_data).to(float).clip(0,1), + "gt_non_aligned": torch.tensor(gt_non_aligned).to(float).permute(2, 0, 1).clip(0,1), + "aligned": torch.tensor(aligned).to(float).permute(2, 0, 1).clip(0,1), + "sparse": torch.tensor(sparse).to(float).clip(0,1), + "noisy": torch.tensor(noisy).to(float).clip(0,1), + "rggb": torch.tensor(rggb).to(float).clip(0,1), + "conditioning": torch.tensor([row.iso/self.coordinate_iso]).to(float), + # "noise_est": noise_est, + # "rggb_gt": 
rggb_gt, + } + return output \ No newline at end of file From 7976b44357ebfcbba50a142b14d5518379a53aaf Mon Sep 17 00:00:00 2001 From: Ryan Mueller Date: Thu, 30 Oct 2025 16:56:40 -0400 Subject: [PATCH 36/56] DNG training --- 0_produce_small_dataset_raw.ipynb | 284 +++++++++++++++++++++++++++--- 2_pretrain_model_raw.ipynb | 34 +++- 2 files changed, 289 insertions(+), 29 deletions(-) diff --git a/0_produce_small_dataset_raw.ipynb b/0_produce_small_dataset_raw.ipynb index 8d818c2..48e83c2 100644 --- a/0_produce_small_dataset_raw.ipynb +++ b/0_produce_small_dataset_raw.ipynb @@ -226,10 +226,6 @@ " height = uint_img.shape[0]\n", " bpp = 16\n", "\n", - " exposures = get_ratios('EXIF ExposureTime', rh)\n", - " fnumber = get_ratios('EXIF FNumber', rh)\n", - " ExposureBiasValue = get_ratios('EXIF ExposureBiasValue', rh) \n", - " FocalLength = get_ratios('EXIF FocalLength', rh) \n", " ccm1 = convert_ccm_to_rational(rh.core_metadata.rgb_xyz_matrix[:3, :])\n", " t = DNGTags()\n", " t.set(Tag.ImageWidth, width)\n", @@ -243,7 +239,6 @@ "\n", " t.set(Tag.TileWidth, width)\n", " t.set(Tag.TileLength, height)\n", - " t.set(Tag.Orientation, rh.full_metadata['Image Orientation'].values[0])\n", " t.set(Tag.PhotometricInterpretation, PhotometricInterpretation.Color_Filter_Array)\n", " t.set(Tag.CFARepeatPatternDim, [2,2])\n", " t.set(Tag.CFAPattern, CFAPattern.RGGB)\n", @@ -259,17 +254,25 @@ " wb = get_as_shot_neutral(rh)\n", " t.set(Tag.AsShotNeutral, wb)\n", " t.set(Tag.BaselineExposure, [[0,100]])\n", - " t.set(Tag.Make, rh.full_metadata['Image Make'].values)\n", - " t.set(Tag.Model, rh.full_metadata['Image Model'].values)\n", "\n", "\n", "\n", - " t.set(Tag.FocalLength, FocalLength)\n", - " t.set(Tag.EXIFPhotoLensModel, rh.full_metadata['EXIF LensModel'].values)\n", - " t.set(Tag.ExposureBiasValue, ExposureBiasValue)\n", - " t.set(Tag.ExposureTime, exposures)\n", - " t.set(Tag.FNumber, fnumber)\n", - " t.set(Tag.PhotographicSensitivity, rh.full_metadata['EXIF 
ISOSpeedRatings'].values)\n", + " try:\n", + " t.set(Tag.Make, rh.full_metadata['Image Make'].values)\n", + " t.set(Tag.Model, rh.full_metadata['Image Model'].values)\n", + " exposures = get_ratios('EXIF ExposureTime', rh)\n", + " fnumber = get_ratios('EXIF FNumber', rh)\n", + " ExposureBiasValue = get_ratios('EXIF ExposureBiasValue', rh) \n", + " FocalLength = get_ratios('EXIF FocalLength', rh) \n", + " t.set(Tag.FocalLength, FocalLength)\n", + " t.set(Tag.EXIFPhotoLensModel, rh.full_metadata['EXIF LensModel'].values)\n", + " t.set(Tag.ExposureBiasValue, ExposureBiasValue)\n", + " t.set(Tag.ExposureTime, exposures)\n", + " t.set(Tag.FNumber, fnumber)\n", + " t.set(Tag.PhotographicSensitivity, rh.full_metadata['EXIF ISOSpeedRatings'].values)\n", + " t.set(Tag.Orientation, rh.full_metadata['Image Orientation'].values[0])\n", + " except:\n", + " \"ok\"\n", " t.set(Tag.DNGVersion, DNGVersion.V1_4)\n", " t.set(Tag.DNGBackwardVersion, DNGVersion.V1_2)\n", " t.set(Tag.PreviewColorSpace, PreviewColorSpace.Adobe_RGB)\n", @@ -291,17 +294,19 @@ "for key in tqdm(pair_file_list.keys()):\n", " image_pairs = pair_file_list[key]\n", " for noisy, gt in image_pairs:\n", - " try:\n", - " bayer, rh = get_file(f'{raw_path}/{noisy}')\n", - " noisy_path = outpath / (noisy)\n", - " to_dng(bayer, rh, str(noisy_path))\n", - "\n", - " gt_path = outpath / (gt)\n", - " if not os.path.exists(gt_path):\n", - " bayer, rh = get_file(f'{raw_path}/{gt}')\n", - " to_dng(bayer, rh, str(gt_path))\n", - " except:\n", - " print(f\"Skipping {raw_path}/{noisy}, {raw_path}/{gt}\")\n" + " noisy_path = outpath / (noisy)\n", + " if not os.path.exists(str(noisy_path)+'.dng'):\n", + " print(noisy_path)\n", + " if noisy.endswith(('.cr2', '.nef', '.arw', '.orf', '.raf', '.pef', '.crw', '.dng')):\n", + " bayer, rh = get_file(f'{raw_path}/{noisy}')\n", + " to_dng(bayer, rh, str(noisy_path))\n", + "\n", + "\n", + " gt_path = outpath / (gt)\n", + " if not os.path.exists(str(gt_path)+'.dng'):\n", + " 
print(gt_path)\n", + " bayer, rh = get_file(f'{raw_path}/{gt}')\n", + " to_dng(bayer, rh, str(gt_path))\n" ] }, { @@ -367,6 +372,237 @@ "source": [ "(imdng-im).mean()" ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7c495ba8", + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "import os\n", + "from torch.utils.data import Dataset\n", + "import imageio\n", + "from colour_demosaicing import (\n", + " ROOT_RESOURCES_EXAMPLES,\n", + " demosaicing_CFA_Bayer_bilinear,\n", + " demosaicing_CFA_Bayer_Malvar2004,\n", + " demosaicing_CFA_Bayer_Menon2007,\n", + " mosaicing_CFA_Bayer)\n", + "\n", + "from src.training.utils import inverse_gamma_tone_curve, cfa_to_sparse\n", + "import numpy as np\n", + "import torch \n", + "from src.training.align_images import apply_alignment, align_clean_to_noisy\n", + "from pathlib import Path\n", + "from RawHandler.RawHandler import RawHandler\n", + "\n", + "\n", + "\n", + "def global_affine_match(A, D, mask=None):\n", + " \"\"\"\n", + " Fit D ≈ a + b*A with least squares.\n", + " A, D : 2D arrays, same shape (linear values)\n", + " mask : optional boolean array, True=use pixel\n", + " returns: a, b, D_pred, D_resid (D - (a + b*A))\n", + " \"\"\"\n", + " A = A.ravel().astype(np.float64)\n", + " D = D.ravel().astype(np.float64)\n", + " if mask is None:\n", + " mask = np.isfinite(A) & np.isfinite(D)\n", + " else:\n", + " mask = mask.ravel() & np.isfinite(A) & np.isfinite(D)\n", + "\n", + " A0 = A[mask]\n", + " D0 = D[mask]\n", + " # design matrix [1, A]\n", + " X = np.vstack([np.ones_like(A0), A0]).T\n", + " coef, *_ = np.linalg.lstsq(X, D0, rcond=None)\n", + " a, b = coef[0], coef[1]\n", + " D_pred = (a + b * A).reshape(-1)\n", + " D_pred = D_pred.reshape(A.shape) if False else (a + b * A).reshape((-1,)) # keep flatten\n", + "\n", + " return a, b, (a + b * A)\n", + "\n", + "\n", + "def random_crop_dim(shape, crop_size, buffer, validation=False):\n", + " h, w = shape\n", + " if not 
validation:\n", + " top = np.random.randint(0 + buffer, h - crop_size - buffer)\n", + " left = np.random.randint(0 + buffer, w - crop_size - buffer)\n", + " else:\n", + " top = (h - crop_size) // 2\n", + " left = (w - crop_size) // 2\n", + "\n", + " if top % 2 != 0: top = top - 1\n", + " if left % 2 != 0: left = left - 1\n", + " bottom = top + crop_size\n", + " right = left + crop_size\n", + " return (left, right, top, bottom)\n", + "\n", + "class RawDatasetDNG(Dataset):\n", + " def __init__(self, path, csv, colorspace, crop_size=180, buffer=10, validation=False, run_align=False, dimensions=2000):\n", + " super().__init__()\n", + " self.df = pd.read_csv(csv)\n", + " self.path = path\n", + " self.crop_size = crop_size\n", + " self.buffer = buffer\n", + " self.coordinate_iso = 6400\n", + " self.validation=validation\n", + " self.run_align = run_align\n", + " self.dtype = np.float16\n", + " self.dimensions = dimensions\n", + " self.colorspace = colorspace\n", + "\n", + " def __len__(self):\n", + " return len(self.df)\n", + " \n", + " def __getitem__(self, idx):\n", + " row = self.df.iloc[idx]\n", + " # Load images\n", + " name = Path(f\"{row.bayer_path}\").name\n", + " name = str(self.path / name.replace('_bayer.jpg', '.dng'))\n", + " noisy_rh = RawHandler(name)\n", + " \n", + " name = Path(f\"{row.gt_path}\").name\n", + " name = str(self.path / name.replace('.jpg', '.dng'))\n", + " gt_rh = RawHandler(name)\n", + "\n", + "\n", + " dims = random_crop_dim(noisy_rh.raw.shape, self.crop_size, self.buffer, validation=self.validation)\n", + " bayer_data = noisy_rh.apply_colorspace_transform(dims=dims, colorspace=self.colorspace)\n", + " noisy = noisy_rh.as_rgb(dims=dims, colorspace=self.colorspace)\n", + " rggb = noisy_rh.as_rggb(dims=dims, colorspace=self.colorspace)\n", + "\n", + " expanded_dims = [dims[0]-self.buffer, dims[1]+self.buffer, dims[0]-self.buffer, dims[1]+self.buffer]\n", + " gt_expanded = gt_rh.as_rgb(dims=expanded_dims, colorspace=self.colorspace)\n", + " 
aligned = apply_alignment(gt_expanded.transpose(1, 2, 0), row.to_dict())[self.buffer:-self.buffer, self.buffer:-self.buffer]\n", + " gt_non_aligned = gt_expanded.transpose(1, 2, 0)[self.buffer:-self.buffer, self.buffer:-self.buffer]\n", + " # Convert to tensors\n", + " output = {\n", + " \"bayer\": torch.tensor(bayer_data).to(float).clip(0,1), \n", + " \"gt_non_aligned\": torch.tensor(gt_non_aligned).to(float).permute(2, 0, 1).clip(0,1), \n", + " \"aligned\": torch.tensor(aligned).to(float).permute(2, 0, 1).clip(0,1), \n", + " # \"sparse\": torch.tensor(sparse).to(float).clip(0,1),\n", + " \"noisy\": torch.tensor(noisy).to(float).clip(0,1), \n", + " \"rggb\": torch.tensor(rggb).to(float).clip(0,1),\n", + " \"conditioning\": torch.tensor([row.iso/self.coordinate_iso]).to(float), \n", + " # \"noise_est\": noise_est,\n", + " # \"rggb_gt\": rggb_gt,\n", + " }\n", + " return output" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "04525aa2", + "metadata": {}, + "outputs": [], + "source": [ + "from src.training.load_config import load_config\n", + "\n", + "run_config = load_config()\n", + "dataset_path = Path(run_config['cropped_raw_subdir'])\n", + "align_csv = dataset_path / run_config['secondary_align_csv']\n", + "\n", + "\n", + "device=run_config['device']\n", + "\n", + "batch_size = run_config['batch_size']\n", + "lr = run_config['lr_base'] * batch_size\n", + "clipping = run_config['clipping']\n", + "\n", + "num_epochs = run_config['num_epochs_pretraining']\n", + "val_split = run_config['val_split']\n", + "crop_size = run_config['crop_size']\n", + "experiment = run_config['mlflow_experiment']\n", + "mlflow_path = run_config['mlflow_path']\n", + "colorspace = run_config['colorspace']\n", + "\n", + "rggb = True" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c529a7f1", + "metadata": {}, + "outputs": [], + "source": [ + "dataset = SmallRawDatasetNumpy(dataset_path, align_csv, colorspace, crop_size=crop_size, 
validation=True)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d2b506c1", + "metadata": {}, + "outputs": [], + "source": [ + "output = dataset[0]\n", + "import matplotlib.pyplot as plt\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b5ee6bbe", + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "plt.imshow(output['noisy'].permute(1, 2, 0)**(1/2.2))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4e05c86a", + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "plt.imshow((output['noisy']-output['aligned']).permute(1, 2, 0)+0.5)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6036f9a3", + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "plt.imshow((output['noisy']-output['gt_non_aligned']).permute(1, 2, 0)+0.5)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "203687d3", + "metadata": {}, + "outputs": [], + "source": [ + "output['gt_non_aligned'].shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2f684a2f", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "819fa74d", + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { diff --git a/2_pretrain_model_raw.ipynb b/2_pretrain_model_raw.ipynb index 44de687..e20ea90 100644 --- a/2_pretrain_model_raw.ipynb +++ b/2_pretrain_model_raw.ipynb @@ -25,7 +25,7 @@ "metadata": {}, "outputs": [], "source": [ - "from src.training.SmallRawDatasetNumpy import SmallRawDatasetNumpy\n", + "from src.training.RawDatasetDNG import RawDatasetDNG\n", "from src.training.losses.ShadowAwareLoss import ShadowAwareLoss\n", "from src.training.VGGFeatureExtractor import VGGFeatureExtractor\n", "from src.training.train_loop import train_one_epoch, visualize\n", @@ -74,6 +74,8 @@ "crop_size = 
run_config['crop_size']\n", "experiment = run_config['mlflow_experiment']\n", "mlflow_path = run_config['mlflow_path']\n", + "colorspace = run_config['colorspace']\n", + "\n", "rggb = True\n", "mlflow.set_tracking_uri(f\"file://{mlflow_path}\")\n", "mlflow.set_experiment(experiment)" @@ -87,7 +89,7 @@ "outputs": [], "source": [ "\n", - "RUN_ID = \"425568ac95d340d7a59c624233269207\" \n", + "RUN_ID = \"eb0face7e288444e9aa20ffdb8ccdb23\" \n", "ARTIFACT_PATH = run_config['run_path']\n", "\n", "model_uri = f\"runs:/{RUN_ID}/{ARTIFACT_PATH}\"\n", @@ -109,7 +111,9 @@ "metadata": {}, "outputs": [], "source": [ - "dataset = SmallRawDatasetNumpy(dataset_path, align_csv, crop_size=crop_size)\n", + "dataset = RawDatasetDNG(dataset_path, align_csv, colorspace, crop_size=crop_size)\n", + "dataset.df = dataset.df[~dataset.df.bayer_path.str.contains('crw')]\n", + "dataset.df = dataset.df[~dataset.df.bayer_path.str.contains('dng_bayer')]\n", "\n", "# Split dataset into train and val\n", "val_size = int(len(dataset) * val_split)\n", @@ -120,8 +124,8 @@ "val_dataset = copy.deepcopy(val_dataset)\n", "val_dataset.dataset.validation = True\n", "\n", - "train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False, num_workers=0)\n", - "val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0)" + "train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)\n", + "val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True, num_workers=0)" ] }, { @@ -146,12 +150,24 @@ " ssim_weight=run_config['ssim_weight'],\n", " tv_weight=run_config['tv_weight'],\n", " vgg_loss_weight=run_config['vgg_loss_weight'],\n", + " percept_loss_weight=run_config['percept_loss_weight'],\n", " apply_gamma_fn=apply_gamma_torch,\n", " vgg_feature_extractor=vfe,\n", " device=device,\n", ")" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "c2bea1ed", + "metadata": {}, + "outputs": [], + "source": [ + "from 
RawHandler.RawHandler import RawHandler\n", + "RawHandler('/Users/ryanmueller/Develop/Cropped_Raw/Bayer_TEST_Pen-pile_ISO400_sha1=fad719aac1eccd9a0e87409fd6c5249ab21aa60c.dng.dng')" + ] + }, { "cell_type": "code", "execution_count": null, @@ -180,6 +196,14 @@ "source": [ "run.info.run_id" ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bf1f13e1", + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { From 46dfc9a543f96597b15aa06f45410a441013ea61 Mon Sep 17 00:00:00 2001 From: Ryan Mueller Date: Fri, 31 Oct 2025 13:57:04 -0400 Subject: [PATCH 37/56] Updated gitignore --- .gitignore | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 10f7995..663e14a 100644 --- a/.gitignore +++ b/.gitignore @@ -199,4 +199,6 @@ mlruns/ .DS_store *.png *.jpeg -*.csv \ No newline at end of file +*.csv +*.xmp +*.dng \ No newline at end of file From f1b398152980eeac8b3774026b92d69b0e0cb58d Mon Sep 17 00:00:00 2001 From: Ryan Mueller Date: Fri, 31 Oct 2025 13:57:26 -0400 Subject: [PATCH 38/56] Aditional testing options in cond naf --- src/Restorer/Cond_NAF.py | 180 +++++++++++++++++++++++++++++++++++++-- 1 file changed, 171 insertions(+), 9 deletions(-) diff --git a/src/Restorer/Cond_NAF.py b/src/Restorer/Cond_NAF.py index efebe4a..b43cda2 100644 --- a/src/Restorer/Cond_NAF.py +++ b/src/Restorer/Cond_NAF.py @@ -2,6 +2,28 @@ import torch import torch.nn as nn + +class LayerNorm2dAdjusted(nn.Module): + def __init__(self, channels, eps=1e-6): + super(LayerNorm2d, self).__init__() + self.register_parameter("weight", nn.Parameter(torch.ones(channels))) + self.register_parameter("bias", nn.Parameter(torch.zeros(channels))) + self.eps = eps + + def forward(self, x, target_mu, target_var): + mu = x.mean(1, keepdim=True) + var = (x - mu).pow(2).mean(1, keepdim=True) + + y = (x - mu) / torch.sqrt(var + self.eps) + + y = y * torch.sqrt(target_var + self.eps) + target_mu + + weight_view = self.weight.view(1, 
self.weight.size(0), 1, 1) + bias_view = self.bias.view(1, self.bias.size(0), 1, 1) + + y = weight_view * y + bias_view + return y + class LayerNorm2d(nn.Module): def __init__(self, channels, eps=1e-6): super(LayerNorm2d, self).__init__() @@ -369,6 +391,113 @@ def forward(self, input): return (y + x * self.gamma, cond) +class NAFBlock0AdjustedNorm(nn.Module): + def __init__(self, c, DW_Expand=2, FFN_Expand=2, drop_out_rate=0.0, cond_chans=0): + super().__init__() + dw_channel = c * DW_Expand + self.conv1 = nn.Conv2d( + in_channels=c, + out_channels=dw_channel, + kernel_size=1, + padding=0, + stride=1, + groups=1, + bias=True, + ) + self.conv2 = nn.Conv2d( + in_channels=dw_channel, + out_channels=dw_channel, + kernel_size=3, + padding=1, + stride=1, + groups=dw_channel, + bias=True, + ) + self.conv3 = nn.Conv2d( + in_channels=dw_channel // 2, + out_channels=c, + kernel_size=1, + padding=0, + stride=1, + groups=1, + bias=True, + ) + + # Simplified Channel Attention + self.sca = ConditionedChannelAttention(dw_channel // 2, cond_chans) + + # SimpleGate + self.sg = SimpleGate() + + ffn_channel = FFN_Expand * c + self.conv4 = nn.Conv2d( + in_channels=c, + out_channels=ffn_channel, + kernel_size=1, + padding=0, + stride=1, + groups=1, + bias=True, + ) + self.conv5 = nn.Conv2d( + in_channels=ffn_channel // 2, + out_channels=c, + kernel_size=1, + padding=0, + stride=1, + groups=1, + bias=True, + ) + + # self.grn = GRN(ffn_channel // 2) + + self.norm1 = LayerNorm2dAdjusted(c) + self.norm2 = LayerNorm2dAdjusted(c) + + self.dropout1 = ( + nn.Dropout(drop_out_rate) if drop_out_rate > 0.0 else nn.Identity() + ) + self.dropout2 = ( + nn.Dropout(drop_out_rate) if drop_out_rate > 0.0 else nn.Identity() + ) + + self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True) + self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True) + self.sca_mul = ConditionedChannelAttention(c, cond_chans) + self.sca_add = ConditionedChannelAttention(c, cond_chans) + + def 
forward(self, input): + inp = input[0] + cond = input[1] + + x = inp + + x = self.norm1(x, mu, var) + x = self.conv1(x) + x = self.conv2(x) + x = self.sg(x) + x = x * self.sca(x, cond) + x = self.conv3(x) + + x = self.dropout1(x) + + y = inp + x * self.beta + + # Channel Mixing + normed = self.norm2(y, mu, var) + + # Input mediated channel attention, obstensibly to mitigate the effects of group norm on flat scenes + # x = (1 + self.sca_mul(inp, cond)) * normed + self.sca_add(inp, cond) + + x = self.conv4(normed) + x = self.sg(x) + # x = self.grn(x) + x = self.conv5(x) + + x = self.dropout2(x) + + return (y + x * self.gamma, cond, mu, var) + import torch.nn.functional as F @@ -501,12 +630,15 @@ def __init__( use_cond_net = False, cond_net_num = 32, use_input_stats=False, + use_NAFBlock0AdjustedNorm=False, ): super().__init__() if use_attnblock: block = AttnBlock elif use_NAFBlock0_learned_norm: block = NAFBlock0_learned_norm + elif use_NAFBlock0AdjustedNorm: + block = NAFBlock0AdjustedNorm else: block = NAFBlock0 @@ -640,14 +772,14 @@ def forward(self, inp, cond_in): # Conditioning: cond = self.conditioning_gen(cond_in) - if self.use_cond_net: - extra_cond = self.cond_net(inp) - cond = torch.cat([cond, extra_cond], dim=1) - if self.use_input_stats: - mu = inp.mean((2,3), keepdim=True) - var = (inp - mu).pow(2).mean((2,3), keepdim=False) - mu = mu.squeeze(-1).squeeze(-1) - cond = torch.cat([cond, mu, var], dim=1) + # if self.use_cond_net: + # extra_cond = self.cond_net(inp) + # cond = torch.cat([cond, extra_cond], dim=1) + # if self.use_input_stats: + # mu = inp.mean((2,3), keepdim=True) + # var = (inp - mu).pow(2).mean((2,3), keepdim=False) + # mu = mu.squeeze(-1).squeeze(-1) + # cond = torch.cat([cond, mu, var], dim=1) B, C, H, W = inp.shape if self.rggb: @@ -682,18 +814,48 @@ def check_image_size(self, x): class ModelWrapper(nn.Module): def __init__(self, **kwargs): + self.gamma = 1 + if 'gamma' in kwargs: + self.gamma = kwargs.pop('gamma') super().__init__() 
self.model = Restorer( **kwargs ) def forward(self, x, cond, residual): + x = x.clip(0, 1) ** (1. / self.gamma) + residual = residual.clip(0, 1) ** (1. / self.gamma) output = self.model(x, cond) - return residual + output + output = (residual + output).clip(0, 1) ** (self.gamma) + return output def make_full_model_RGGB(params, model_name=None): model = ModelWrapper(**params) + if not model_name is None: + state_dict = torch.load(model_name, map_location="cpu") + model.load_state_dict(state_dict) + return model + + +class DemosaicingModelWrapper(nn.Module): + def __init__(self, **kwargs): + self.gamma = 1 + if 'gamma' in kwargs: + self.gamma = kwargs.pop('gamma') + super().__init__() + self.model = Restorer( + **kwargs + ) + + def forward(self, x, cond): + x = x.clip(0, 1) ** (1. / self.gamma) + output = (self.model(x, cond)).clip(0,1) ** (self.gamma) + return output + + +def make_full_model_RGGB_Demosaicing(params, model_name=None): + model = DemosaicingModelWrapper(**params) if not model_name is None: state_dict = torch.load(model_name, map_location="cpu") model.load_state_dict(state_dict) From e15541fb4ebef56c7d59f6f3c3c5a60e3c9dd3e0 Mon Sep 17 00:00:00 2001 From: Ryan Mueller Date: Fri, 31 Oct 2025 13:57:42 -0400 Subject: [PATCH 39/56] Added sharpness loss to shadow aware loss --- src/training/losses/ShadowAwareLoss.py | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/src/training/losses/ShadowAwareLoss.py b/src/training/losses/ShadowAwareLoss.py index 3714932..cf56db1 100644 --- a/src/training/losses/ShadowAwareLoss.py +++ b/src/training/losses/ShadowAwareLoss.py @@ -2,6 +2,7 @@ import torch.nn as nn from pytorch_msssim import ms_ssim from src.training.losses.CombinedPerceptualLoss import VGGPerceptualLoss +import torchvision class ShadowAwareLoss(nn.Module): def __init__(self, @@ -14,7 +15,8 @@ def __init__(self, apply_gamma_fn=None, vgg_feature_extractor=None, percept_loss_weight = 0, - device=None): + device=None, + 
sharpness_loss_weight=0): """ Shadow-aware image restoration loss. @@ -38,6 +40,7 @@ def __init__(self, self.device = device self.percept_loss_weight = percept_loss_weight self.VGGPerceptualLoss = VGGPerceptualLoss() + self.sharpness_loss_weight = sharpness_loss_weight if device is not None: self.to(device) @@ -83,13 +86,26 @@ def forward(self, pred, target): percept_loss = 0 if self.percept_loss_weight: percept_loss = self.VGGPerceptualLoss(pred, target) + + sharpness_loss_value = 0 + if self.sharpness_loss_weight: + sharpness_loss_value += sharpness_loss(pred, target) + # Combine weighted terms total_loss = ( self.l1_weight * l1 + self.ssim_weight * ssim + self.tv_weight * tv + self.vgg_loss_weight * vgg_loss_val + - self.percept_loss_weight * percept_loss + self.percept_loss_weight * percept_loss + + self.sharpness_loss_weight * sharpness_loss_value ) return total_loss + +def sharpness_loss(pred, target, loss_func = torch.nn.functional.l1_loss): + loss = loss_func(pred, target) + pred_prime = torchvision.transforms.functional.gaussian_blur(pred, kernel_size=[5, 5], sigma=[1.0, 1.0]) + target_prime = torchvision.transforms.functional.gaussian_blur(target, kernel_size=[5, 5], sigma=[1.0, 1.0]) + loss += loss_func(pred-pred_prime, target-target_prime) + return loss From d8353006648ed9aeec71fe6890e35a58897025ab Mon Sep 17 00:00:00 2001 From: Ryan Mueller Date: Fri, 31 Oct 2025 13:57:59 -0400 Subject: [PATCH 40/56] Tracking lr in training loop --- src/training/train_loop.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/training/train_loop.py b/src/training/train_loop.py index 7c75292..e79a5d8 100644 --- a/src/training/train_loop.py +++ b/src/training/train_loop.py @@ -63,6 +63,7 @@ def train_one_epoch(epoch, _model, _optimizer, _loader, _device, _loss_func, _cl mlflow.log_metric("train_loss", total_loss/n_images, step=epoch) mlflow.log_metric("l1_loss", total_l1_loss/n_images, step=epoch) mlflow.log_metric("epoch_duration_s", train_time, step=epoch) + 
mlflow.log_metric("learning_rate", _optimizer.param_groups[0]['lr'], step=epoch) return total_loss / max(1, n_images), perf_counter()-start From 81fbfc51d5c8d02ebc6bc516b8646fa4143d9529 Mon Sep 17 00:00:00 2001 From: Ryan Mueller Date: Fri, 31 Oct 2025 13:58:17 -0400 Subject: [PATCH 41/56] Updates to raw dataset --- src/training/SmallRawDatasetNumpy.py | 50 ++++++++++++++++++++++++---- 1 file changed, 44 insertions(+), 6 deletions(-) diff --git a/src/training/SmallRawDatasetNumpy.py b/src/training/SmallRawDatasetNumpy.py index f624699..15dd614 100644 --- a/src/training/SmallRawDatasetNumpy.py +++ b/src/training/SmallRawDatasetNumpy.py @@ -15,6 +15,33 @@ from src.training.align_images import apply_alignment, align_clean_to_noisy from pathlib import Path + + +def global_affine_match(A, D, mask=None): + """ + Fit D ≈ a + b*A with least squares. + A, D : 2D arrays, same shape (linear values) + mask : optional boolean array, True=use pixel + returns: a, b, D_pred, D_resid (D - (a + b*A)) + """ + A = A.ravel().astype(np.float64) + D = D.ravel().astype(np.float64) + if mask is None: + mask = np.isfinite(A) & np.isfinite(D) + else: + mask = mask.ravel() & np.isfinite(A) & np.isfinite(D) + + A0 = A[mask] + D0 = D[mask] + # design matrix [1, A] + X = np.vstack([np.ones_like(A0), A0]).T + coef, *_ = np.linalg.lstsq(X, D0, rcond=None) + a, b = coef[0], coef[1] + D_pred = (a + b * A).reshape(-1) + D_pred = D_pred.reshape(A.shape) if False else (a + b * A).reshape((-1,)) # keep flatten + + return a, b, (a + b * A) + class SmallRawDatasetNumpy(Dataset): def __init__(self, path, csv, crop_size=180, buffer=10, validation=False, run_align=False, dimensions=2000): super().__init__() @@ -40,7 +67,7 @@ def __getitem__(self, idx): bayer_data = bayer_data.reshape((self.dimensions, self.dimensions)) name = Path(f"{row.gt_path}").name - name = name.replace('jpg', 'u16.raw') + name = name.replace('jpg', 'u16.raw') gt_image = np.fromfile(self.path / name, dtype=self.dtype) gt_image = 
gt_image.reshape((self.dimensions, self.dimensions)) @@ -78,21 +105,32 @@ def __getitem__(self, idx): h, w, _ = gt_image.shape demosaiced_noisy = demosaicing_CFA_Bayer_Malvar2004(bayer_data) - - aligned = aligned * demosaiced_noisy.mean() / aligned.mean() - gt_image = gt_image * demosaiced_noisy.mean() / gt_image.mean() - sparse, _ = cfa_to_sparse(bayer_data) rggb = bayer_data.reshape(h // 2, 2, w // 2, 2, 1).transpose(1, 3, 4, 0, 2).reshape(4, h // 2, w // 2) + # # Affine transform to match brightness in gt to noisy + # a, b, aligned = global_affine_match(aligned, demosaiced_noisy) + + # print(a, b, aligned.shape) + # gt_image = gt_image * demosaiced_noisy.mean() / gt_image.mean() + # aligned = aligned * demosaiced_noisy.mean() / aligned.mean() + + + # Sim noise method + mosaiced_gt = mosaicing_CFA_Bayer(aligned) + rggb_gt = mosaiced_gt.reshape(h // 2, 2, w // 2, 2, 1).transpose(1, 3, 4, 0, 2).reshape(4, h // 2, w // 2) + noise_est = rggb - rggb_gt + # Convert to tensors output = { "bayer": torch.tensor(bayer_data).to(float).clip(0,1), - "gt": torch.tensor(gt_image).to(float).permute(2, 0, 1).clip(0,1), + "gt_non_aligned": torch.tensor(gt_image).to(float).permute(2, 0, 1).clip(0,1), "aligned": torch.tensor(aligned).to(float).permute(2, 0, 1).clip(0,1), "sparse": torch.tensor(sparse).to(float).clip(0,1), "noisy": torch.tensor(demosaiced_noisy).to(float).permute(2, 0, 1).clip(0,1), "rggb": torch.tensor(rggb).to(float).clip(0,1), "conditioning": torch.tensor([row.iso/self.coordinate_iso]).to(float), + "noise_est": noise_est, + "rggb_gt": rggb_gt, } return output \ No newline at end of file From 4e739400994e595f558e4e42b61c5f105f40f06a Mon Sep 17 00:00:00 2001 From: Ryan Mueller Date: Fri, 31 Oct 2025 13:58:38 -0400 Subject: [PATCH 42/56] Updates to config and new config --- config.yaml | 18 +++++++------ config_demosaicing.yaml | 59 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 69 insertions(+), 8 deletions(-) create mode 100644 config_demosaicing.yaml 
diff --git a/config.yaml b/config.yaml index 4dfd53a..0467e0c 100644 --- a/config.yaml +++ b/config.yaml @@ -3,7 +3,7 @@ base_data_dir: /Volumes/EasyStore/RAWNIND/ jpeg_output_subdir: /Volumes/EasyStore/RAWNIND/JPEGs/Cropped_JPEG cropped_jpeg_subdir: /Volumes/EasyStore/RAWNIND/JPEGs/Cropped_JPEG -cropped_raw_subdir: /Volumes/EasyStore/RAWNIND/Cropped_Raw +cropped_raw_subdir: /Users/ryanmueller/Develop/Cropped_Raw/ cropped_raw_size: 2000 align_csv: align_data.csv secondary_align_csv: align_phase_v2.csv @@ -17,18 +17,19 @@ batch_size: 2 crop_size: 256 lr_base: 2.5e-5 clipping: 1e-2 -num_epochs_pretraining: 50 +num_epochs_pretraining: 75 num_epochs_finetuning: 20 val_split: 0.2 random_seed: 42 - +cosine_annealing: False +iso_range: [0, 999999] # --- Experiment Settings --- experiment: NAF_test mlflow_experiment: NAFNet_variations # --- Run Configuration ---: -run_name: NAF_use_input_stats +run_name: NAF_gamma_test run_path: NAF_deep_test_align model_params: chans: [32, 64, 128, 256, 256, 256] @@ -44,8 +45,9 @@ model_params: use_CondFuserV3: False use_attnblock: False use_CondFuserV4: False - use_NAFBlock0_learned_norm: True - use_input_stats: True + use_NAFBlock0_learned_norm: False + use_input_stats: False + gamma: 2.2 # --- Loss Configureation ---: alpha: 0.2 @@ -53,5 +55,5 @@ beta: 5.0 l1_weight: 0.16 ssim_weight: 0.84 tv_weight: 0.0 -vgg_loss_weight: 0 -percept_loss_weight: 0 \ No newline at end of file +vgg_loss_weight: 0.0 +percept_loss_weight: 0.025 \ No newline at end of file diff --git a/config_demosaicing.yaml b/config_demosaicing.yaml new file mode 100644 index 0000000..6ecabe2 --- /dev/null +++ b/config_demosaicing.yaml @@ -0,0 +1,59 @@ +# config.yaml +# --- Paths --- +base_data_dir: /Volumes/EasyStore/RAWNIND/ +jpeg_output_subdir: /Volumes/EasyStore/RAWNIND/JPEGs/Cropped_JPEG +cropped_jpeg_subdir: /Volumes/EasyStore/RAWNIND/JPEGs/Cropped_JPEG +cropped_raw_subdir: /Users/ryanmueller/Develop/Cropped_Raw/ +cropped_raw_size: 2000 +align_csv: align_data.csv 
+secondary_align_csv: align_phase_v2.csv +script_path: /Volumes/EasyStore/models/traces +mlflow_path: /Volumes/EasyStore/models/mlfow + +# --- Training Params --- +colorspace: lin_rec2020 +device: mps +batch_size: 2 +crop_size: 256 +lr_base: 2.5e-5 +clipping: 1e-2 +num_epochs_pretraining: 75 +num_epochs_finetuning: 20 +val_split: 0.2 +random_seed: 42 +cosine_annealing: True +iso_range: [0, 999999] + +# --- Experiment Settings --- +experiment: Demosaicing +mlflow_experiment: Demosaicing + +# --- Run Configuration ---: +run_name: Demosaicing_test +run_path: Demosaicing_Test +model_params: + chans: [12, 24] + enc_blk_nums: [1] + middle_blk_num: 1 + dec_blk_nums: [1] + cond_input: 1 + in_channels: 4 + out_channels: 3 + rggb: True + use_CondFuserV2: False + use_add: False + use_CondFuserV3: False + use_attnblock: False + use_CondFuserV4: False + use_NAFBlock0_learned_norm: False + use_input_stats: False + gamma: 1.0 + +# --- Loss Configureation ---: +alpha: 0.2 +beta: 5.0 +l1_weight: 0.16 +ssim_weight: 0.84 +tv_weight: 0.0 +vgg_loss_weight: 0.0 +percept_loss_weight: 0.025 \ No newline at end of file From 2ff78ec30beee484dd71c8da5a4f8c9ae8d0247f Mon Sep 17 00:00:00 2001 From: Ryan Mueller Date: Fri, 31 Oct 2025 16:32:58 -0400 Subject: [PATCH 43/56] Script to train on raw from scratch --- 1_train_model_raw.ipynb | 216 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 216 insertions(+) create mode 100644 1_train_model_raw.ipynb diff --git a/1_train_model_raw.ipynb b/1_train_model_raw.ipynb new file mode 100644 index 0000000..f68ce2e --- /dev/null +++ b/1_train_model_raw.ipynb @@ -0,0 +1,216 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "f6351e77", + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "import matplotlib.pyplot as plt\n", + "from torch.utils.data import DataLoader, random_split\n", + "import torch.nn as nn\n", + "import torch\n", + "import copy\n", + "import mlflow\n", + "import 
mlflow.pytorch\n", + "from pathlib import Path" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2043dc8e", + "metadata": {}, + "outputs": [], + "source": [ + "from src.training.RawDatasetDNG import RawDatasetDNG\n", + "from src.training.losses.ShadowAwareLoss import ShadowAwareLoss\n", + "from src.training.VGGFeatureExtractor import VGGFeatureExtractor\n", + "from src.training.train_loop import train_one_epoch, visualize\n", + "from src.training.utils import apply_gamma_torch\n", + "from src.training.load_config import load_config\n", + "from src.Restorer.Cond_NAF import make_full_model_RGGB\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a0464c98", + "metadata": {}, + "outputs": [], + "source": [ + "run_config = load_config()\n", + "dataset_path = Path(run_config['cropped_raw_subdir'])\n", + "align_csv = dataset_path / run_config['secondary_align_csv']" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7322456f", + "metadata": {}, + "outputs": [], + "source": [ + "dataset_path" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ba20b866", + "metadata": {}, + "outputs": [], + "source": [ + "device=run_config['device']\n", + "\n", + "batch_size = run_config['batch_size']\n", + "lr = run_config['lr_base'] * batch_size\n", + "clipping = run_config['clipping']\n", + "\n", + "num_epochs = run_config['num_epochs_pretraining']\n", + "cosine_annealing = run_config['cosine_annealing']\n", + "\n", + "val_split = run_config['val_split']\n", + "crop_size = run_config['crop_size']\n", + "experiment = run_config['mlflow_experiment']\n", + "mlflow_path = run_config['mlflow_path']\n", + "colorspace = run_config['colorspace']\n", + "iso_range = run_config['iso_range']\n", + "\n", + "rggb = True\n", + "mlflow.set_tracking_uri(f\"file://{mlflow_path}\")\n", + "mlflow.set_experiment(experiment)\n", + "\n", + "params = {**run_config}" + ] + }, + { + "cell_type": "code", + "execution_count": 
null, + "id": "e9a26124", + "metadata": {}, + "outputs": [], + "source": [ + "model_params = run_config['model_params']\n", + "rggb = model_params['rggb']\n", + "\n", + "model = make_full_model_RGGB(model_params, model_name=None)\n", + "model = model.to(device)\n", + "\n", + "params = {**run_config, **model_params}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "15f16fa7", + "metadata": {}, + "outputs": [], + "source": [ + "dataset = RawDatasetDNG(dataset_path, align_csv, colorspace, crop_size=crop_size)\n", + "dataset.df = dataset.df[~dataset.df.bayer_path.str.contains('crw')]\n", + "dataset.df = dataset.df[~dataset.df.bayer_path.str.contains('dng_bayer')]\n", + "dataset.df = dataset.df[(dataset.df.iso >= iso_range[0]) & (dataset.df.iso <= iso_range[1])]\n", + "print(len(dataset.df ))\n", + "# Split dataset into train and val\n", + "val_size = int(len(dataset) * val_split)\n", + "train_size = len(dataset) - val_size\n", + "torch.manual_seed(42) # For reproducibility\n", + "train_dataset, val_dataset = random_split(dataset, [train_size, val_size])\n", + "# Set the validation dataset to use the same crops\n", + "val_dataset = copy.deepcopy(val_dataset)\n", + "val_dataset.dataset.validation = True\n", + "\n", + "train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)\n", + "val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True, num_workers=0)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6af0f3a2", + "metadata": {}, + "outputs": [], + "source": [ + "optimizer = torch.optim.Adam(model.parameters(), lr=lr)\n", + "if cosine_annealing:\n", + " sched = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, num_epochs,eta_min=lr*1e-6)\n", + "else:\n", + " sched = None\n", + " \n", + "vfe = VGGFeatureExtractor(config=((1, 64), (1, 128), (1, 256), (1, 512), (1, 512),), \n", + " feature_layers=[14], \n", + " activation=nn.ReLU\n", + " )\n", + "vfe = vfe.to(device)\n", + 
"\n", + "loss_fn = ShadowAwareLoss(\n", + " alpha=run_config['alpha'],\n", + " beta=run_config['beta'],\n", + " l1_weight=run_config['l1_weight'],\n", + " ssim_weight=run_config['ssim_weight'],\n", + " tv_weight=run_config['tv_weight'],\n", + " vgg_loss_weight=run_config['vgg_loss_weight'],\n", + " percept_loss_weight=run_config['percept_loss_weight'],\n", + " apply_gamma_fn=apply_gamma_torch,\n", + " vgg_feature_extractor=vfe,\n", + " device=device,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8bc86a06", + "metadata": {}, + "outputs": [], + "source": [ + "with mlflow.start_run(run_name=run_config['run_name']) as run:\n", + " mlflow.log_params(params)\n", + " for epoch in range(num_epochs):\n", + " train_one_epoch(epoch, model, optimizer, train_loader, device, loss_fn, clipping, \n", + " log_interval = 10, sleep=0.0, rggb=rggb, max_batches=0)\n", + " if cosine_annealing:\n", + " sched.step()\n", + " \n", + " mlflow.pytorch.log_model(\n", + " pytorch_model=model,\n", + " name=run_config['run_path'],\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5a775d31", + "metadata": {}, + "outputs": [], + "source": [ + "run.info.run_id" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "OnSight", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.11" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} From b4b0676494d20c7896136f8d3b98db85b5de44d4 Mon Sep 17 00:00:00 2001 From: Ryan Mueller Date: Fri, 31 Oct 2025 16:33:17 -0400 Subject: [PATCH 44/56] Improved validaiton script --- 1_validate_model.ipynb | 394 ++++++++++++++++++++++++++++++++++++++++- 1 file changed, 391 insertions(+), 3 deletions(-) diff --git a/1_validate_model.ipynb 
b/1_validate_model.ipynb index a909f15..b1400c0 100644 --- a/1_validate_model.ipynb +++ b/1_validate_model.ipynb @@ -42,7 +42,7 @@ "source": [ "run_config = load_config()\n", "dataset_path = Path(run_config['jpeg_output_subdir'])\n", - "align_csv = dataset_path / run_config['align_csv']\n", + "align_csv = dataset_path / run_config['secondary_align_csv']\n", "device=run_config['device']\n", "\n", "batch_size = run_config['batch_size']\n", @@ -54,7 +54,10 @@ "crop_size = run_config['crop_size']\n", "experiment = run_config['mlflow_experiment']\n", "model_params = run_config['model_params']\n", - "rggb = model_params['rggb']" + "rggb = model_params['rggb']\n", + "mlflow_path = run_config['mlflow_path']\n", + "mlflow.set_tracking_uri(f\"file://{mlflow_path}\")\n", + "mlflow.set_experiment(experiment)" ] }, { @@ -65,7 +68,7 @@ "outputs": [], "source": [ "\n", - "RUN_ID = \"66f11f4639f24e9ea75b4c953147be15\" \n", + "RUN_ID = \"b0664f324e9444d3b3a5277d513d3642\" \n", "ARTIFACT_PATH = run_config['run_path']\n", "\n", "model_uri = f\"runs:/{RUN_ID}/{ARTIFACT_PATH}\"\n", @@ -130,6 +133,20 @@ ")" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "51ad2666", + "metadata": {}, + "outputs": [], + "source": [ + "# with mlflow.start_run(run_id=existing_run_id) as run:\n", + "# print(f\"Re-opened run: {run.info.run_name}\")\n", + " \n", + "# # Log your new validation metric\n", + "# mlflow.log_metric(\"final_validation_accuracy\", new_validation_metric)" + ] + }, { "cell_type": "code", "execution_count": null, @@ -153,6 +170,377 @@ "source": [ "visualize(matching_indices_in_subset, model, val_dataset, device, loss_fn, rggb=rggb)" ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "641d864b", + "metadata": {}, + "outputs": [], + "source": [ + "Train loss: 0.056005 Final image val loss: 0.009073 Time: 45.7s Images: 25\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "256f0de3", + "metadata": {}, + "outputs": [], + "source": 
[ + "import os\n", + "import torch\n", + "import mlflow\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "from time import perf_counter\n", + "\n", + "\n", + "def make_conditioning(conditioning, device):\n", + " B = conditioning.shape[0]\n", + " conditioning_extended = torch.zeros(B, 1).to(device)\n", + " conditioning_extended[:, 0] = conditioning[:, 0]\n", + " return conditioning_extended\n", + "\n", + "\n", + "def validate_model(\n", + " _model, \n", + " val_dataset, \n", + " _device, \n", + " _loss_func, epoch,\n", + " iso_values=(100, 400, 1600, 65535), \n", + " n_examples=3, \n", + " rggb=False,\n", + " artifact_dir=\"val_examples\"\n", + "):\n", + " \"\"\"\n", + " Run validation over different ISO values and log metrics + sample images to MLflow.\n", + "\n", + " Args:\n", + " _model: torch model\n", + " val_dataset: dataset with `df.iso` values accessible\n", + " _device: torch device\n", + " _loss_func: loss function\n", + " iso_values: list/tuple of ISO values to evaluate\n", + " n_examples: number of images to log per ISO\n", + " rggb: whether to use rggb input\n", + " artifact_dir: subdirectory to store example images for MLflow\n", + " \"\"\"\n", + "\n", + " _model.eval()\n", + " os.makedirs(artifact_dir, exist_ok=True)\n", + "\n", + " all_metrics = {}\n", + "\n", + " for iso in iso_values:\n", + " # Select subset indices with this ISO\n", + " subset_indices = np.array(val_dataset.indices)\n", + " mask = val_dataset.dataset.df.iso.values[subset_indices] == iso\n", + " matching_indices_in_subset = np.nonzero(mask)[0]\n", + "\n", + " if len(matching_indices_in_subset) == 0:\n", + " print(f\"No validation samples for ISO {iso}\")\n", + " continue\n", + "\n", + "\n", + " idxs = matching_indices_in_subset[:n_examples]\n", + "\n", + " total_loss, final_img_loss, duration = visualize(\n", + " idxs, _model, val_dataset, _device, _loss_func, rggb=rggb\n", + " )\n", + "\n", + "\n", + " # Save and log a few example denoised images\n", + " for 
i, idx in enumerate(idxs):\n", + " row = val_dataset[idx]\n", + " noisy = row['noisy'].unsqueeze(0).float().to(_device)\n", + " conditioning = make_conditioning(row['conditioning'].float().unsqueeze(0).to(_device), _device)\n", + " gt = row['aligned'].unsqueeze(0).float().to(_device)\n", + " input = row['rggb' if rggb else 'sparse'].unsqueeze(0).float().to(_device)\n", + "\n", + " with torch.no_grad():\n", + " with torch.autocast(device_type=\"mps\", dtype=torch.bfloat16):\n", + " pred = _model(input, conditioning, noisy)\n", + "\n", + " pred_img = apply_gamma_torch(pred[0].cpu().permute(1, 2, 0))\n", + " noisy_img = apply_gamma_torch(noisy[0].cpu().permute(1, 2, 0))\n", + " gt_img = apply_gamma_torch(gt[0].cpu().permute(1, 2, 0))\n", + "\n", + " fig, axs = plt.subplots(1, 3, figsize=(15, 5))\n", + " axs[0].imshow(noisy_img)\n", + " axs[0].set_title(f\"Noisy (ISO {iso})\")\n", + " axs[1].imshow(pred_img)\n", + " axs[1].set_title(\"Denoised\")\n", + " axs[2].imshow(gt_img)\n", + " axs[2].set_title(\"Ground Truth\")\n", + " for ax in axs: ax.axis('off')\n", + "\n", + " img_path = os.path.join(artifact_dir, f\"iso{iso}_example{i}.png\")\n", + " plt.savefig(img_path, bbox_inches=\"tight\")\n", + " plt.close(fig)\n", + "\n", + " mlflow.log_artifact(img_path, artifact_path=f\"{artifact_dir}/ISO_{iso}\")\n", + "\n", + " # Log metrics per ISO\n", + " mlflow.log_metrics({\n", + " f\"val_loss_ISO_{iso}\": total_loss,\n", + " f\"final_img_loss_ISO_{iso}\": final_img_loss,\n", + " f\"val_duration_ISO_{iso}\": duration\n", + " }, step=epoch)\n", + "\n", + " # Also log a summary table\n", + " print(\"\\nValidation Summary:\")\n", + " for iso, metrics in all_metrics.items():\n", + " print(f\"ISO {iso}: val_loss={metrics['val_loss']:.6f}, \"\n", + " f\"final_img_loss={metrics['final_image_loss']:.6f}, \"\n", + " f\"time={metrics['duration']:.1f}s\")\n", + "\n", + " return all_metrics\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "15ff7052", + 
"metadata": {}, + "outputs": [], + "source": [ + "with mlflow.start_run(run_name=run_config['run_name']) as run:\n", + " metrics = validate_model(model, val_dataset, device, loss_fn, rggb=True)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "db26581d", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "\n", + "class SwiGLU(nn.Module):\n", + "\n", + " def __init__(self, input_dim, hidden_dim, dropout=0.1):\n", + " super().__init__()\n", + " self.w1 = nn.Conv2d(input_dim, hidden_dim, 1, 1, 0, 1)\n", + " self.w2 = nn.Conv2d(input_dim, hidden_dim, 1, 1, 0, 1)\n", + " self.w3 = nn.Conv2d(hidden_dim, input_dim, 1, 1, 0, 1)\n", + " \n", + " def forward(self, x):\n", + " gate = F.silu(self.w1(x)) \n", + " value = self.w2(x)\n", + " x = gate * value \n", + " \n", + " x = self.w3(x)\n", + " return x" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cc387d19", + "metadata": {}, + "outputs": [], + "source": [ + "class ConditionedChannelAttention(nn.Module):\n", + " def __init__(self, dims, cat_dims):\n", + " super().__init__()\n", + " in_dim = dims + cat_dims\n", + " self.mlp = nn.Sequential(nn.Linear(in_dim, dims))\n", + " self.pool = nn.AdaptiveAvgPool2d(1)\n", + "\n", + " def forward(self, x, conditioning):\n", + " pool = self.pool(x)\n", + " conditioning = conditioning.unsqueeze(-1).unsqueeze(-1)\n", + " cat_channels = torch.cat([pool, conditioning], dim=1)\n", + " cat_channels = cat_channels.permute(0, 2, 3, 1)\n", + " ca = self.mlp(cat_channels).permute(0, 3, 1, 2)\n", + "\n", + " return ca" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "86fdd311", + "metadata": {}, + "outputs": [], + "source": [ + "import torch.nn.functional as F\n", + "\n", + "class SwiGLU(nn.Module):\n", + "\n", + " def __init__(self, input_dim, hidden_dim, dropout=0.1):\n", + " super().__init__()\n", + " self.w1 = nn.Conv2d(input_dim, 
hidden_dim, 1, 1, 0, 1)\n", + " self.w2 = nn.Conv2d(input_dim, hidden_dim, 1, 1, 0, 1)\n", + " self.w3 = nn.Conv2d(hidden_dim, input_dim, 1, 1, 0, 1)\n", + " \n", + " def forward(self, x):\n", + " gate = F.silu(self.w1(x)) \n", + " value = self.w2(x)\n", + " x = gate * value \n", + " \n", + " x = self.w3(x)\n", + " return x\n", + " \n", + "class AttnBlock(nn.Module):\n", + " def __init__(self, c, FFN_Expand=2, drop_out_rate=0.0, cond_chans=0):\n", + " super().__init__()\n", + " \n", + " self.dw = nn.Conv2d(\n", + " in_channels=c,\n", + " out_channels=c,\n", + " kernel_size=3,\n", + " padding=1,\n", + " stride=1,\n", + " groups=c,\n", + " bias=True,\n", + " )\n", + "\n", + " self.sca = ConditionedChannelAttention(c, cond_chans)\n", + "\n", + " self.norm = nn.GroupNorm(1, c)\n", + " \n", + " self.swiglu = SwiGLU(c, int(c * FFN_Expand))\n", + " self.alpha = nn.Parameter(torch.zeros(1, c, 1, 1))\n", + " self.beta = nn.Parameter(torch.zeros(1, c, 1, 1))\n", + "\n", + "\n", + " def forward(self, input):\n", + " inp = input[0]\n", + " cond = input[1]\n", + "\n", + " x = self.dw(inp)\n", + " x = self.sca(x, cond) * x\n", + " y = self.norm(inp + self.alpha * x )\n", + "\n", + "\n", + " x = self.swiglu(y)\n", + " x = y + self.beta * x\n", + " return (x, cond)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "990a5993", + "metadata": {}, + "outputs": [], + "source": [ + "block = AttnBlock(64, 2, cond_chans=1)\n", + "block((torch.rand(1, 64, 32, 32), torch.rand(1, 1)))[0].shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5b460abf", + "metadata": {}, + "outputs": [], + "source": [ + "class NAFBlock0(nn.Module):\n", + " def __init__(self, c, DW_Expand=2, FFN_Expand=2, drop_out_rate=0.0, cond_chans=0):\n", + " super().__init__()\n", + " dw_channel = c * DW_Expand\n", + " self.conv1 = nn.Conv2d(\n", + " in_channels=c,\n", + " out_channels=dw_channel,\n", + " kernel_size=1,\n", + " padding=0,\n", + " stride=1,\n", + " 
groups=1,\n", + " bias=True,\n", + " )\n", + " self.conv2 = nn.Conv2d(\n", + " in_channels=dw_channel,\n", + " out_channels=dw_channel,\n", + " kernel_size=3,\n", + " padding=1,\n", + " stride=1,\n", + " groups=dw_channel,\n", + " bias=True,\n", + " )\n", + " self.conv3 = nn.Conv2d(\n", + " in_channels=dw_channel // 2,\n", + " out_channels=c,\n", + " kernel_size=1,\n", + " padding=0,\n", + " stride=1,\n", + " groups=1,\n", + " bias=True,\n", + " )\n", + "\n", + " # Simplified Channel Attention\n", + " self.sca = ConditionedChannelAttention(dw_channel // 2, cond_chans)\n", + "\n", + " # SimpleGate\n", + " self.sg = SimpleGate()\n", + "\n", + " ffn_channel = FFN_Expand * c\n", + " self.conv4 = nn.Conv2d(\n", + " in_channels=c,\n", + " out_channels=ffn_channel,\n", + " kernel_size=1,\n", + " padding=0,\n", + " stride=1,\n", + " groups=1,\n", + " bias=True,\n", + " )\n", + " self.conv5 = nn.Conv2d(\n", + " in_channels=ffn_channel // 2,\n", + " out_channels=c,\n", + " kernel_size=1,\n", + " padding=0,\n", + " stride=1,\n", + " groups=1,\n", + " bias=True,\n", + " )\n", + "\n", + " # self.grn = GRN(ffn_channel // 2)\n", + "\n", + " self.norm1 = LayerNorm2d(c)\n", + " self.norm2 = LayerNorm2d(c)\n", + "\n", + " self.dropout1 = (\n", + " nn.Dropout(drop_out_rate) if drop_out_rate > 0.0 else nn.Identity()\n", + " )\n", + " self.dropout2 = (\n", + " nn.Dropout(drop_out_rate) if drop_out_rate > 0.0 else nn.Identity()\n", + " )\n", + "\n", + " self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)\n", + " self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)\n", + "\n", + " def forward(self, input):\n", + " inp = input[0]\n", + " cond = input[1]\n", + "\n", + " x = inp\n", + "\n", + " x = self.norm1(x)\n", + " x = self.conv1(x)\n", + " x = self.conv2(x)\n", + " x = self.sg(x)\n", + " x = x * self.sca(x, cond)\n", + " x = self.conv3(x)\n", + "\n", + " x = self.dropout1(x)\n", + "\n", + " y = inp + x * self.beta\n", + "\n", + " # Channel 
Mixing\n", + " x = self.conv4(self.norm2(y))\n", + " x = self.sg(x)\n", + " # x = self.grn(x)\n", + " x = self.conv5(x)\n", + "\n", + " x = self.dropout2(x)\n", + "\n", + " return (y + x * self.gamma, cond)" + ] } ], "metadata": { From 53aa027110baea483383d8be90765a3b2e17fd56 Mon Sep 17 00:00:00 2001 From: Ryan Mueller Date: Fri, 31 Oct 2025 16:33:51 -0400 Subject: [PATCH 45/56] Short script for data cleaning --- 1_visualize_alignment.ipynb | 279 ++++++++++++++++++++++++++++++++++++ 1 file changed, 279 insertions(+) create mode 100644 1_visualize_alignment.ipynb diff --git a/1_visualize_alignment.ipynb b/1_visualize_alignment.ipynb new file mode 100644 index 0000000..04bbf73 --- /dev/null +++ b/1_visualize_alignment.ipynb @@ -0,0 +1,279 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "f6351e77", + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "import matplotlib.pyplot as plt\n", + "from torch.utils.data import DataLoader, random_split\n", + "import copy\n", + "from pathlib import Path\n", + "import torch" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2043dc8e", + "metadata": {}, + "outputs": [], + "source": [ + "from src.training.SmallRawDatasetNumpy import SmallRawDatasetNumpy\n", + "from src.training.load_config import load_config\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b14af214", + "metadata": {}, + "outputs": [], + "source": [ + "run_config = load_config()\n", + "dataset_path = Path(run_config['cropped_raw_subdir'])\n", + "align_csv = dataset_path / run_config['secondary_align_csv']\n", + "device=run_config['device']\n", + "\n", + "batch_size = run_config['batch_size']\n", + "lr = run_config['lr_base'] * batch_size\n", + "clipping = run_config['clipping']\n", + "\n", + "num_epochs = run_config['num_epochs_pretraining']\n", + "val_split = run_config['val_split']\n", + "crop_size = run_config['crop_size']\n", + "experiment = 
run_config['mlflow_experiment']\n", + "model_params = run_config['model_params']\n", + "rggb = model_params['rggb']" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "15f16fa7", + "metadata": {}, + "outputs": [], + "source": [ + "dataset = SmallRawDatasetNumpy(dataset_path, align_csv, crop_size=crop_size, validation=True)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "51ad2666", + "metadata": {}, + "outputs": [], + "source": [ + "# with mlflow.start_run(run_id=existing_run_id) as run:\n", + "# print(f\"Re-opened run: {run.info.run_name}\")\n", + " \n", + "# # Log your new validation metric\n", + "# mlflow.log_metric(\"final_validation_accuracy\", new_validation_metric)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ae01182c", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4aa24966", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "import pandas as pd\n", + "from scipy.signal import correlate2d\n", + "from scipy.ndimage import gaussian_filter\n", + "from skimage.metrics import structural_similarity as ssim\n", + "\n", + "# --- Metric functions ---\n", + "\n", + "def structure_score_autocorr(img_np):\n", + " img = img_np - img_np.mean()\n", + " corr = correlate2d(img, img, mode='full')\n", + " corr /= corr.max() + 1e-8\n", + " center = np.array(corr.shape) // 2\n", + " corr[center[0], center[1]] = 0\n", + " return float(np.mean(corr))\n", + "\n", + "def fourier_entropy(img_np):\n", + " f = np.fft.fftshift(np.fft.fft2(img_np))\n", + " mag = np.abs(f)\n", + " mag = mag / (mag.sum() + 1e-8)\n", + " return float(-np.sum(mag * np.log(mag + 1e-12)))\n", + "\n", + "def laplacian_variance(img_np):\n", + " import cv2\n", + " lap = cv2.Laplacian(img_np, cv2.CV_64F)\n", + " return float(lap.var())\n", + "\n", + "def self_ssim(img_np, 
sigma=1.0):\n", + " blurred = gaussian_filter(img_np, sigma)\n", + " return float(ssim(img_np, blurred, data_range=img_np.max() - img_np.min()))\n", + "\n", + "# --- Visualization and analysis function ---\n", + "\n", + "def analyze_dataset_structure(dataset, indices, out_dir=\"structure_analysis\"):\n", + " import os\n", + " os.makedirs(out_dir, exist_ok=True)\n", + " records = []\n", + "\n", + " for idx in indices:\n", + " row = dataset[idx]\n", + " noisy = row['noisy']\n", + " gt = row['aligned']\n", + "\n", + " # Ensure CPU numpy grayscale (if multi-channel)\n", + " def to_gray(x):\n", + " if isinstance(x, torch.Tensor):\n", + " x = x.detach().cpu().numpy()\n", + " if x.ndim == 3 and x.shape[0] in [1,3]:\n", + " x = np.moveaxis(x, 0, -1)\n", + " if x.ndim == 3:\n", + " x = np.mean(x, axis=-1)\n", + " return np.clip(x, 0, 1)\n", + "\n", + " noisy_np = to_gray(noisy)\n", + " gt_np = to_gray(gt)\n", + " diff_np = np.abs(noisy_np - gt_np)\n", + "\n", + " # Compute metrics\n", + " metrics = {\n", + " \"idx\": idx,\n", + " \"fourier_entropy_noisy\": fourier_entropy(noisy_np),\n", + " \"lap_var_noisy\": laplacian_variance(noisy_np),\n", + " \"self_ssim_noisy\": self_ssim(noisy_np),\n", + " \"autocorr_noisy\": structure_score_autocorr(noisy_np),\n", + " \"fourier_entropy_diff\": fourier_entropy(diff_np),\n", + " \"lap_var_diff\": laplacian_variance(diff_np),\n", + " }\n", + " records.append(metrics)\n", + "\n", + " # Plot\n", + " fig, axes = plt.subplots(1, 3, figsize=(12, 4))\n", + " for ax, im, title in zip(axes, [noisy_np, gt_np, diff_np],\n", + " ['Noisy', 'Aligned', 'Difference']):\n", + " ax.imshow(im, cmap='gray')\n", + " ax.set_title(title)\n", + " ax.axis('off')\n", + "\n", + " txt = \"\\n\".join([\n", + " f\"Entropy: {metrics['fourier_entropy_noisy']:.3f}\",\n", + " f\"LapVar: {metrics['lap_var_noisy']:.3f}\",\n", + " f\"SelfSSIM: {metrics['self_ssim_noisy']:.3f}\",\n", + " f\"AutoCorr: {metrics['autocorr_noisy']:.3f}\",\n", + " ])\n", + " 
plt.gcf().text(0.75, 0.5, txt, fontsize=10, va='center')\n", + " plt.tight_layout()\n", + " plt.savefig(f\"{out_dir}/example_{idx:04d}.png\", dpi=150)\n", + " plt.close(fig)\n", + "\n", + " df = pd.DataFrame.from_records(records)\n", + " df.to_csv(f\"{out_dir}/structure_metrics.csv\", index=False)\n", + " print(f\"Saved {len(df)} examples and metrics to {out_dir}\")\n", + " return df\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "56f5e484", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "import pandas as pd\n", + "from scipy.signal import correlate2d\n", + "from scipy.ndimage import gaussian_filter\n", + "from skimage.metrics import structural_similarity as ssim\n", + "\n", + "def analyze_dataset_structure(dataset, indices, out_dir=\"structure_analysis\"):\n", + " import os\n", + " os.makedirs(out_dir, exist_ok=True)\n", + " records = []\n", + "\n", + " for idx in indices:\n", + " row = dataset[idx]\n", + " noisy = row['noisy'].permute(1, 2, 0)**(1/2.2)\n", + " gt = row['aligned'].permute(1, 2, 0)**(1/2.2)\n", + " diff = noisy-gt +0.5\n", + " # Plot\n", + " fig, axes = plt.subplots(1, 3, figsize=(12, 4))\n", + " for ax, im, title in zip(axes, [noisy, gt, diff],\n", + " ['Noisy', 'Aligned', 'Difference']):\n", + " ax.imshow(im)\n", + " ax.set_title(title)\n", + " ax.axis('off')\n", + "\n", + " plt.tight_layout()\n", + " plt.show()\n", + " plt.close(fig)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "59269994", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "mask = dataset.df.iso.values == 400\n", + "matching_indices_in_subset = np.nonzero(mask)[0]\n", + "matching_indices_in_subset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "54c87086", + "metadata": {}, + "outputs": [], + "source": [ + "analyze_dataset_structure(dataset, 
matching_indices_in_subset,out_dir=\"structure_analysis_no_align\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "422b379e", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "OnSight", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.11" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} From ffcce3a410669cccef0c04c511b2ad2675d821b3 Mon Sep 17 00:00:00 2001 From: Ryan Mueller Date: Fri, 31 Oct 2025 16:35:15 -0400 Subject: [PATCH 46/56] Code to finetune model on raw data --- ...el_raw.ipynb => 2_finetune_model_raw.ipynb | 60 ++++++++++++++----- 1 file changed, 44 insertions(+), 16 deletions(-) rename 2_pretrain_model_raw.ipynb => 2_finetune_model_raw.ipynb (76%) diff --git a/2_pretrain_model_raw.ipynb b/2_finetune_model_raw.ipynb similarity index 76% rename from 2_pretrain_model_raw.ipynb rename to 2_finetune_model_raw.ipynb index e20ea90..9a7d529 100644 --- a/2_pretrain_model_raw.ipynb +++ b/2_finetune_model_raw.ipynb @@ -69,16 +69,21 @@ "lr = run_config['lr_base'] * batch_size\n", "clipping = run_config['clipping']\n", "\n", - "num_epochs = run_config['num_epochs_pretraining']\n", + "num_epochs = run_config['num_epochs_finetuning']\n", + "cosine_annealing = run_config['cosine_annealing']\n", + "\n", "val_split = run_config['val_split']\n", "crop_size = run_config['crop_size']\n", "experiment = run_config['mlflow_experiment']\n", "mlflow_path = run_config['mlflow_path']\n", "colorspace = run_config['colorspace']\n", + "iso_range = run_config['iso_range']\n", "\n", "rggb = True\n", "mlflow.set_tracking_uri(f\"file://{mlflow_path}\")\n", - "mlflow.set_experiment(experiment)" + "mlflow.set_experiment(experiment)\n", + "\n", + "params = 
{**run_config}" ] }, { @@ -89,8 +94,10 @@ "outputs": [], "source": [ "\n", - "RUN_ID = \"eb0face7e288444e9aa20ffdb8ccdb23\" \n", + "RUN_ID = \"7d9ffb05e2c747fe93647e06ef43e51b\" \n", "ARTIFACT_PATH = run_config['run_path']\n", + "params['base_RUN_ID'] = RUN_ID\n", + "params['base_ARTIFACT_PATH'] = ARTIFACT_PATH\n", "\n", "model_uri = f\"runs:/{RUN_ID}/{ARTIFACT_PATH}\"\n", "\n", @@ -114,6 +121,7 @@ "dataset = RawDatasetDNG(dataset_path, align_csv, colorspace, crop_size=crop_size)\n", "dataset.df = dataset.df[~dataset.df.bayer_path.str.contains('crw')]\n", "dataset.df = dataset.df[~dataset.df.bayer_path.str.contains('dng_bayer')]\n", + "dataset.df = dataset.df[(dataset.df.iso >= iso_range[0]) & (dataset.df.iso <= iso_range[1])]\n", "\n", "# Split dataset into train and val\n", "val_size = int(len(dataset) * val_split)\n", @@ -136,7 +144,11 @@ "outputs": [], "source": [ "optimizer = torch.optim.Adam(model.parameters(), lr=lr)\n", - "\n", + "if cosine_annealing:\n", + " sched = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, num_epochs,eta_min=lr*1e-6)\n", + "else:\n", + " sched = None\n", + " \n", "vfe = VGGFeatureExtractor(config=((1, 64), (1, 128), (1, 256), (1, 512), (1, 512),), \n", " feature_layers=[14], \n", " activation=nn.ReLU\n", @@ -160,12 +172,34 @@ { "cell_type": "code", "execution_count": null, - "id": "c2bea1ed", + "id": "e7b3706e", "metadata": {}, "outputs": [], "source": [ - "from RawHandler.RawHandler import RawHandler\n", - "RawHandler('/Users/ryanmueller/Develop/Cropped_Raw/Bayer_TEST_Pen-pile_ISO400_sha1=fad719aac1eccd9a0e87409fd6c5249ab21aa60c.dng.dng')" + "optimizer = torch.optim.Adam(model.parameters(), lr=lr)\n", + "if cosine_annealing:\n", + " sched = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, num_epochs,eta_min=lr*1e-6)\n", + "else:\n", + " sched = None\n", + " \n", + "vfe = VGGFeatureExtractor(config=((1, 64), (1, 128), (1, 256), (1, 512), (1, 512),), \n", + " feature_layers=[14], \n", + " activation=nn.ReLU\n", + " )\n", + 
"vfe = vfe.to(device)\n", + "\n", + "loss_fn = ShadowAwareLoss(\n", + " alpha=run_config['alpha'],\n", + " beta=run_config['beta'],\n", + " l1_weight=run_config['l1_weight'],\n", + " ssim_weight=run_config['ssim_weight'],\n", + " tv_weight=run_config['tv_weight'],\n", + " vgg_loss_weight=run_config['vgg_loss_weight'],\n", + " percept_loss_weight=run_config['percept_loss_weight'],\n", + " apply_gamma_fn=apply_gamma_torch,\n", + " vgg_feature_extractor=vfe,\n", + " device=device,\n", + ")" ] }, { @@ -176,10 +210,12 @@ "outputs": [], "source": [ "with mlflow.start_run(run_name=run_config['run_name']) as run:\n", - "\n", + " mlflow.log_params(params)\n", " for epoch in range(num_epochs):\n", " train_one_epoch(epoch, model, optimizer, train_loader, device, loss_fn, clipping, \n", " log_interval = 10, sleep=0.0, rggb=rggb, max_batches=0)\n", + " if cosine_annealing:\n", + " sched.step()\n", " \n", " mlflow.pytorch.log_model(\n", " pytorch_model=model,\n", @@ -196,14 +232,6 @@ "source": [ "run.info.run_id" ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "bf1f13e1", - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { From a2161aad5e17a1ae064d38bb1d7792753193682a Mon Sep 17 00:00:00 2001 From: Ryan Mueller Date: Fri, 31 Oct 2025 16:36:20 -0400 Subject: [PATCH 47/56] Updating raw dataset to provided aligned raw data for noise characterization and simulation --- src/training/RawDatasetDNG.py | 25 ++++++++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/src/training/RawDatasetDNG.py b/src/training/RawDatasetDNG.py index 103d40f..6c9a162 100644 --- a/src/training/RawDatasetDNG.py +++ b/src/training/RawDatasetDNG.py @@ -59,6 +59,19 @@ def random_crop_dim(shape, crop_size, buffer, validation=False): right = left + crop_size return (left, right, top, bottom) +def check_align_matrix(row, tolerance=1e-7): + is_identity = np.isclose(row['M00'], 1.0, atol=tolerance) and \ + np.isclose(row['M01'], 0.0, atol=tolerance) 
and \ + np.isclose(row['M10'], 0.0, atol=tolerance) and \ + np.isclose(row['M11'], 1.0, atol=tolerance) + + assert is_identity, "Rotations, scalings, or shearing are not tested." + + +def round_to_nearest_2(number): + return round(number / 2) * 2 + + class RawDatasetDNG(Dataset): def __init__(self, path, csv, colorspace, crop_size=180, buffer=10, validation=False, run_align=False, dimensions=2000): super().__init__() @@ -99,11 +112,19 @@ def __getitem__(self, idx): noisy = noisy_rh.as_rgb(dims=dims, colorspace=self.colorspace) rggb = noisy_rh.as_rggb(dims=dims, colorspace=self.colorspace) sparse = noisy_rh.as_sparse(dims=dims, colorspace=self.colorspace) - + + check_align_matrix(row) expanded_dims = [dims[0]-self.buffer, dims[1]+self.buffer, dims[2]-self.buffer, dims[3]+self.buffer] gt_expanded = gt_rh.as_rgb(dims=expanded_dims, colorspace=self.colorspace) aligned = apply_alignment(gt_expanded.transpose(1, 2, 0), row.to_dict())[self.buffer:-self.buffer, self.buffer:-self.buffer] gt_non_aligned = gt_expanded.transpose(1, 2, 0)[self.buffer:-self.buffer, self.buffer:-self.buffer] + # # gt_non_aligned = gt_non_aligned * noisy.mean() / aligned.mean() + # # aligned = aligned * noisy.mean() / aligned.mean() + # # Get Raw data for testing + # noisy_raw = noisy_rh.raw[dims[0]:dims[1], dims[2]: dims[3]] + # row_dict = row.to_dict() + # shift_y, shift_x = round_to_nearest_2(row_dict['M12']), round_to_nearest_2(row_dict['M11']) + # gt_raw = gt_rh.raw[dims[0]+shift_y:dims[1]+shift_y, dims[2]+shift_x:dims[3]+shift_x] except: print(name, gt_name) @@ -118,6 +139,8 @@ def __getitem__(self, idx): "noisy": torch.tensor(noisy).to(float).clip(0,1), "rggb": torch.tensor(rggb).to(float).clip(0,1), "conditioning": torch.tensor([row.iso/self.coordinate_iso]).to(float), + # "noisy_raw": noisy_raw, + # "gt_raw": gt_raw, # "noise_est": noise_est, # "rggb_gt": rggb_gt, } From ff4506e9779482823d79fd930dff95a2b1824fe4 Mon Sep 17 00:00:00 2001 From: Ryan Mueller Date: Fri, 31 Oct 2025 21:52:56 
-0400 Subject: [PATCH 48/56] Code to train demosaicing models --- 1_train_model_raw_demosaicing.ipynb | 282 ++++++++++++++++++++++++++ 1_validated_raw_demosaicing.ipynb | 303 ++++++++++++++++++++++++++++ 2 files changed, 585 insertions(+) create mode 100644 1_train_model_raw_demosaicing.ipynb create mode 100644 1_validated_raw_demosaicing.ipynb diff --git a/1_train_model_raw_demosaicing.ipynb b/1_train_model_raw_demosaicing.ipynb new file mode 100644 index 0000000..6543f21 --- /dev/null +++ b/1_train_model_raw_demosaicing.ipynb @@ -0,0 +1,282 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "f6351e77", + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "import matplotlib.pyplot as plt\n", + "from torch.utils.data import DataLoader, random_split\n", + "import torch.nn as nn\n", + "import torch\n", + "import copy\n", + "import mlflow\n", + "import mlflow.pytorch\n", + "from pathlib import Path" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2043dc8e", + "metadata": {}, + "outputs": [], + "source": [ + "from src.training.DemosaicingDataset import DemosaicingDataset\n", + "from src.training.losses.ShadowAwareLoss import ShadowAwareLoss\n", + "from src.training.VGGFeatureExtractor import VGGFeatureExtractor\n", + "from src.training.train_loop import train_one_epoch, visualize\n", + "from src.training.utils import apply_gamma_torch\n", + "from src.training.load_config import load_config\n", + "from src.Restorer.Cond_NAF import make_full_model_RGGB_Demosaicing\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a0464c98", + "metadata": {}, + "outputs": [], + "source": [ + "run_config = load_config('config_demosaicing.yaml')\n", + "dataset_path = Path(run_config['cropped_raw_subdir'])\n", + "align_csv = dataset_path / run_config['secondary_align_csv']" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ba20b866", + "metadata": {}, + "outputs": [], 
+ "source": [ + "device=run_config['device']\n", + "\n", + "batch_size = run_config['batch_size']\n", + "lr = run_config['lr_base'] * batch_size\n", + "clipping = run_config['clipping']\n", + "\n", + "num_epochs = run_config['num_epochs_pretraining']\n", + "cosine_annealing = run_config['cosine_annealing']\n", + "\n", + "val_split = run_config['val_split']\n", + "crop_size = run_config['crop_size']\n", + "experiment = run_config['mlflow_experiment']\n", + "mlflow_path = run_config['mlflow_path']\n", + "colorspace = run_config['colorspace']\n", + "iso_range = run_config['iso_range']\n", + "\n", + "rggb = True\n", + "mlflow.set_tracking_uri(f\"file://{mlflow_path}\")\n", + "mlflow.set_experiment(experiment)\n", + "\n", + "params = {**run_config}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e9a26124", + "metadata": {}, + "outputs": [], + "source": [ + "model_params = run_config['model_params']\n", + "rggb = model_params['rggb']\n", + "\n", + "model = make_full_model_RGGB_Demosaicing(model_params, model_name=None)\n", + "model = model.to(device)\n", + "\n", + "params = {**run_config, **model_params}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "15f16fa7", + "metadata": {}, + "outputs": [], + "source": [ + "dataset = DemosaicingDataset(dataset_path, align_csv, colorspace, output_crop_size=crop_size, downsample_factor=4)\n", + "dataset.df = dataset.df[~dataset.df.bayer_path.str.contains('crw')]\n", + "dataset.df = dataset.df[~dataset.df.bayer_path.str.contains('dng_bayer')]\n", + "dataset.df = dataset.df[(dataset.df.iso >= iso_range[0]) & (dataset.df.iso <= iso_range[1])]\n", + "print(len(dataset.df ))\n", + "# Split dataset into train and val\n", + "val_size = int(len(dataset) * val_split)\n", + "train_size = len(dataset) - val_size\n", + "torch.manual_seed(42) # For reproducibility\n", + "train_dataset, val_dataset = random_split(dataset, [train_size, val_size])\n", + "# Set the validation dataset to use the same 
crops\n", + "val_dataset = copy.deepcopy(val_dataset)\n", + "val_dataset.dataset.validation = True\n", + "\n", + "train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)\n", + "val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True, num_workers=0)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6af0f3a2", + "metadata": {}, + "outputs": [], + "source": [ + "optimizer = torch.optim.Adam(model.parameters(), lr=lr)\n", + "if cosine_annealing:\n", + " sched = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, num_epochs,eta_min=lr*1e-6)\n", + "else:\n", + " sched = None\n", + " \n", + "vfe = VGGFeatureExtractor(config=((1, 64), (1, 128), (1, 256), (1, 512), (1, 512),), \n", + " feature_layers=[14], \n", + " activation=nn.ReLU\n", + " )\n", + "vfe = vfe.to(device)\n", + "\n", + "loss_fn = ShadowAwareLoss(\n", + " alpha=run_config['alpha'],\n", + " beta=run_config['beta'],\n", + " l1_weight=run_config['l1_weight'],\n", + " ssim_weight=run_config['ssim_weight'],\n", + " tv_weight=run_config['tv_weight'],\n", + " vgg_loss_weight=run_config['vgg_loss_weight'],\n", + " percept_loss_weight=run_config['percept_loss_weight'],\n", + " apply_gamma_fn=apply_gamma_torch,\n", + " vgg_feature_extractor=vfe,\n", + " device=device,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a4616a32", + "metadata": {}, + "outputs": [], + "source": [ + "from time import perf_counter\n", + "import time\n", + "from tqdm import tqdm\n", + "import torch\n", + "import torch.nn as nn\n", + "from src.training.utils import apply_gamma_torch\n", + "import mlflow\n", + "\n", + "def make_conditioning(conditioning, device):\n", + " B = conditioning.shape[0]\n", + " conditioning_extended = torch.zeros(B, 1).to(device)\n", + " conditioning_extended[:, 0] = conditioning[:, 0]\n", + " return conditioning_extended\n", + "\n", + "\n", + "def train_one_epoch(epoch, _model, _optimizer, _loader, _device, 
_loss_func, _clipping, \n", + " log_interval = 10, sleep=0.0, rggb=False,\n", + " max_batches=0):\n", + " _model.train()\n", + " total_loss, n_images, total_l1_loss = 0.0, 0, 0.0\n", + " start = perf_counter()\n", + " pbar = tqdm(enumerate(_loader), total=len(_loader), desc=f\"Train Epoch {epoch}\")\n", + "\n", + " for batch_idx, (output) in pbar:\n", + " conditioning = output['conditioning'].float().to(_device)\n", + " gt = output['ground_truth'].float().to(_device)\n", + " input = output['cfa_sparse'].float().to(_device)\n", + " if rggb:\n", + " input = output['cfa_rggb'].float().to(_device)\n", + " conditioning = make_conditioning(conditioning, _device)\n", + "\n", + " _optimizer.zero_grad(set_to_none=True)\n", + " pred = _model(input, conditioning) \n", + "\n", + " loss = _loss_func(pred, gt)\n", + " _optimizer.zero_grad(set_to_none=True)\n", + " loss.backward()\n", + " torch.nn.utils.clip_grad_norm_(_model.parameters(), _clipping)\n", + " _optimizer.step()\n", + "\n", + " total_loss += float(loss.detach().cpu())\n", + " n_images += gt.shape[0]\n", + "\n", + " # Testing final image quality\n", + " final_image_loss = float(nn.functional.l1_loss(pred, gt).detach().cpu())\n", + " total_l1_loss += final_image_loss\n", + " del loss, pred, final_image_loss\n", + " torch.mps.empty_cache() \n", + "\n", + " if (batch_idx + 1) % log_interval == 0:\n", + " pbar.set_postfix({\"loss\": f\"{total_loss/n_images:.4f}\"})\n", + "\n", + " if (max_batches > 0) and (batch_idx+1 > max_batches): break\n", + " time.sleep(sleep)\n", + "\n", + " train_time = perf_counter()-start\n", + " print(f\"[Epoch {epoch}] \"\n", + " f\"Train loss: {total_loss/n_images:.6f} \"\n", + " f\"L1 loss: {total_l1_loss/n_images:.6f} \"\n", + " f\"Time: {train_time:.1f}s \"\n", + " f\"Images: {n_images}\")\n", + " mlflow.log_metric(\"train_loss\", total_loss/n_images, step=epoch)\n", + " mlflow.log_metric(\"l1_loss\", total_l1_loss/n_images, step=epoch)\n", + " mlflow.log_metric(\"epoch_duration_s\", 
train_time, step=epoch)\n", + " mlflow.log_metric(\"learning_rate\", _optimizer.param_groups[0]['lr'], step=epoch)\n", + "\n", + " return total_loss / max(1, n_images), perf_counter()-start" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8bc86a06", + "metadata": {}, + "outputs": [], + "source": [ + "with mlflow.start_run(run_name=run_config['run_name']) as run:\n", + " mlflow.log_params(params)\n", + " for epoch in range(num_epochs):\n", + " train_one_epoch(epoch, model, optimizer, train_loader, device, loss_fn, clipping, \n", + " log_interval = 10, sleep=0.0, rggb=rggb, max_batches=0)\n", + " if cosine_annealing:\n", + " sched.step()\n", + " \n", + " mlflow.pytorch.log_model(\n", + " pytorch_model=model,\n", + " name=run_config['run_path'],\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5a775d31", + "metadata": {}, + "outputs": [], + "source": [ + "run.info.run_id" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "OnSight", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.11" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/1_validated_raw_demosaicing.ipynb b/1_validated_raw_demosaicing.ipynb new file mode 100644 index 0000000..b36bc93 --- /dev/null +++ b/1_validated_raw_demosaicing.ipynb @@ -0,0 +1,303 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "f6351e77", + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "import matplotlib.pyplot as plt\n", + "from torch.utils.data import DataLoader, random_split\n", + "import torch.nn as nn\n", + "import torch\n", + "import copy\n", + "import mlflow\n", + "import mlflow.pytorch\n", + "from pathlib import Path" + ] + }, + { + 
"cell_type": "code", + "execution_count": null, + "id": "2043dc8e", + "metadata": {}, + "outputs": [], + "source": [ + "from src.training.DemosaicingDataset import DemosaicingDataset\n", + "from src.training.losses.ShadowAwareLoss import ShadowAwareLoss\n", + "from src.training.VGGFeatureExtractor import VGGFeatureExtractor\n", + "from src.training.train_loop import train_one_epoch, visualize\n", + "from src.training.utils import apply_gamma_torch\n", + "from src.training.load_config import load_config\n", + "from src.Restorer.Cond_NAF import make_full_model_RGGB_Demosaicing\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a0464c98", + "metadata": {}, + "outputs": [], + "source": [ + "run_config = load_config('config_demosaicing.yaml')\n", + "dataset_path = Path(run_config['cropped_raw_subdir'])\n", + "align_csv = dataset_path / run_config['secondary_align_csv']" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ba20b866", + "metadata": {}, + "outputs": [], + "source": [ + "device=run_config['device']\n", + "\n", + "batch_size = run_config['batch_size']\n", + "lr = run_config['lr_base'] * batch_size\n", + "clipping = run_config['clipping']\n", + "\n", + "num_epochs = run_config['num_epochs_pretraining']\n", + "cosine_annealing = run_config['cosine_annealing']\n", + "\n", + "val_split = run_config['val_split']\n", + "crop_size = run_config['crop_size']\n", + "experiment = run_config['mlflow_experiment']\n", + "mlflow_path = run_config['mlflow_path']\n", + "colorspace = run_config['colorspace']\n", + "iso_range = run_config['iso_range']\n", + "\n", + "rggb = True\n", + "mlflow.set_tracking_uri(f\"file://{mlflow_path}\")\n", + "mlflow.set_experiment(experiment)\n", + "\n", + "params = {**run_config}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e9a26124", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "RUN_ID = \"1e253a47ff7d43e1a432cce4ed083c7b\" \n", + "ARTIFACT_PATH = 
run_config['run_path']\n", + "\n", + "model_uri = f\"runs:/{RUN_ID}/{ARTIFACT_PATH}\"\n", + "\n", + "try:\n", + " model = mlflow.pytorch.load_model(model_uri)\n", + " model.eval()\n", + " print(f\"Model successfully loaded from MLflow URI: {model_uri}\")\n", + " \n", + "\n", + "except Exception as e:\n", + " print(f\"Error loading model from MLflow: {e}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "15f16fa7", + "metadata": {}, + "outputs": [], + "source": [ + "dataset = DemosaicingDataset(dataset_path, align_csv, colorspace, output_crop_size=crop_size, downsample_factor=4)\n", + "dataset.df = dataset.df[~dataset.df.bayer_path.str.contains('crw')]\n", + "dataset.df = dataset.df[~dataset.df.bayer_path.str.contains('dng_bayer')]\n", + "dataset.df = dataset.df[(dataset.df.iso >= iso_range[0]) & (dataset.df.iso <= iso_range[1])]\n", + "print(len(dataset.df ))\n", + "# Split dataset into train and val\n", + "val_size = int(len(dataset) * val_split)\n", + "train_size = len(dataset) - val_size\n", + "torch.manual_seed(42) # For reproducibility\n", + "train_dataset, val_dataset = random_split(dataset, [train_size, val_size])\n", + "# Set the validation dataset to use the same crops\n", + "val_dataset = copy.deepcopy(val_dataset)\n", + "val_dataset.dataset.validation = True\n", + "\n", + "train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)\n", + "val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True, num_workers=0)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6af0f3a2", + "metadata": {}, + "outputs": [], + "source": [ + "optimizer = torch.optim.Adam(model.parameters(), lr=lr)\n", + "if cosine_annealing:\n", + " sched = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, num_epochs,eta_min=lr*1e-6)\n", + "else:\n", + " sched = None\n", + " \n", + "vfe = VGGFeatureExtractor(config=((1, 64), (1, 128), (1, 256), (1, 512), (1, 512),), \n", + " 
feature_layers=[14], \n", + " activation=nn.ReLU\n", + " )\n", + "vfe = vfe.to(device)\n", + "\n", + "loss_fn = ShadowAwareLoss(\n", + " alpha=run_config['alpha'],\n", + " beta=run_config['beta'],\n", + " l1_weight=run_config['l1_weight'],\n", + " ssim_weight=run_config['ssim_weight'],\n", + " tv_weight=run_config['tv_weight'],\n", + " vgg_loss_weight=run_config['vgg_loss_weight'],\n", + " percept_loss_weight=run_config['percept_loss_weight'],\n", + " apply_gamma_fn=apply_gamma_torch,\n", + " vgg_feature_extractor=vfe,\n", + " device=device,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a4616a32", + "metadata": {}, + "outputs": [], + "source": [ + "from time import perf_counter\n", + "import time\n", + "from tqdm import tqdm\n", + "import torch\n", + "import torch.nn as nn\n", + "from src.training.utils import apply_gamma_torch\n", + "import mlflow\n", + "\n", + "def make_conditioning(conditioning, device):\n", + " B = conditioning.shape[0]\n", + " conditioning_extended = torch.zeros(B, 1).to(device)\n", + " conditioning_extended[:, 0] = conditioning[:, 0]\n", + " return conditioning_extended\n", + "\n", + "\n", + "from colour_demosaicing import (\n", + " mosaicing_CFA_Bayer,\n", + " demosaicing_CFA_Bayer_Menon2007\n", + ")\n", + "def visualize(idxs, _model, dataset, _device, _loss_func, rggb=False):\n", + " import matplotlib.pyplot as plt\n", + " _model.train()\n", + " total_loss, n_images, total_final_image_loss = 0.0, 0, 0.0\n", + " start = perf_counter()\n", + "\n", + " for idx in idxs:\n", + " row = dataset[idx]\n", + " conditioning = row['conditioning'].unsqueeze(0).float().to(_device)\n", + " gt = row['ground_truth'].unsqueeze(0).float().to(_device)\n", + " sparse = row['cfa_sparse'].unsqueeze(0).float().to(_device)\n", + " input = sparse\n", + " if rggb:\n", + " input = row['cfa_rggb'].unsqueeze(0).float().to(_device)\n", + "\n", + " conditioning = make_conditioning(conditioning, _device)\n", + " \n", + " with 
torch.no_grad():\n", + " with torch.autocast(device_type=\"mps\", dtype=torch.bfloat16):\n", + " pred = _model(input, conditioning) \n", + " loss = _loss_func(pred, gt)\n", + "\n", + " total_loss += float(loss.detach().cpu())\n", + " n_images += gt.shape[0]\n", + "\n", + " # Testing final image quality\n", + " final_image_loss = nn.functional.l1_loss(pred, gt)\n", + " total_final_image_loss += final_image_loss.item()\n", + "\n", + " plt.subplots(2, 3, figsize=(30, 15))\n", + "\n", + " plt.subplot(2, 3, 1)\n", + " plt.title('pred')\n", + " pred = apply_gamma_torch(pred[0].cpu().permute(1, 2, 0))\n", + " plt.imshow(pred)\n", + "\n", + " plt.subplot(2, 3, 2)\n", + " plt.title('Menon')\n", + " bayer = sparse[0].sum(axis=0)\n", + " bayer = apply_gamma_torch(bayer).cpu().numpy()\n", + " trad_demosaiced = demosaicing_CFA_Bayer_Menon2007(bayer)\n", + " plt.imshow(trad_demosaiced)\n", + "\n", + " plt.subplot(2, 3, 3)\n", + " plt.title('gt')\n", + "\n", + " gt = apply_gamma_torch(gt[0].cpu().permute(1, 2, 0))\n", + " plt.imshow(gt)\n", + "\n", + " plt.subplot(2, 3, 4)\n", + " plt.imshow(pred - gt + 0.5)\n", + "\n", + "\n", + " plt.subplot(2, 3, 5)\n", + " plt.imshow(trad_demosaiced - pred.cpu().numpy() + 0.5)\n", + "\n", + " plt.subplot(2, 3, 6)\n", + " plt.imshow(trad_demosaiced - gt.cpu().numpy() + 0.5)\n", + " plt.show()\n", + " plt.clf()\n", + "\n", + " n_images = len(idxs)\n", + " print(\n", + " f\"Train loss: {total_loss/n_images:.6f} \"\n", + " f\"Final image val loss: {total_final_image_loss/n_images:.6f} \"\n", + " f\"Time: {perf_counter()-start:.1f}s \"\n", + " f\"Images: {n_images}\")\n", + "\n", + " return total_loss / max(1, n_images), total_final_image_loss / max(1, n_images), perf_counter()-start" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b5b87533", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "subset_indices = np.array(val_dataset.indices) # indices in the original dataset\n", + "mask = 
val_dataset.dataset.df.iso.values[subset_indices] == 200\n", + "matching_indices_in_subset = np.nonzero(mask)[0]\n", + "matching_indices_in_subset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cde66163", + "metadata": {}, + "outputs": [], + "source": [ + "visualize(matching_indices_in_subset, model, val_dataset, device, loss_fn, rggb=rggb)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "OnSight", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.11" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} From 13147eb4f1a1727fb86ab18e558e472df43c7362 Mon Sep 17 00:00:00 2001 From: Ryan Mueller Date: Fri, 31 Oct 2025 21:53:19 -0400 Subject: [PATCH 49/56] Update to demosaicing dataset --- src/training/DemosaicingDataset.py | 250 +++++++++++++++++++++++++++++ 1 file changed, 250 insertions(+) create mode 100644 src/training/DemosaicingDataset.py diff --git a/src/training/DemosaicingDataset.py b/src/training/DemosaicingDataset.py new file mode 100644 index 0000000..3c78971 --- /dev/null +++ b/src/training/DemosaicingDataset.py @@ -0,0 +1,250 @@ +import pandas as pd +import os +from torch.utils.data import Dataset +import imageio +from colour_demosaicing import ( + mosaicing_CFA_Bayer +) +import numpy as np +import torch +import torch.nn.functional as F +from pathlib import Path +from RawHandler.RawHandler import RawHandler +from src.training.utils import cfa_to_sparse + +def random_crop_dim(shape, crop_size, buffer, validation=False): + """ + Calculates random (or centered) crop dimensions, ensuring even coordinates. 
+ """ + h, w = shape + + # Ensure crop size is even + crop_size = (crop_size // 2) * 2 + + if not validation: + top = np.random.randint(0 + buffer, h - crop_size - buffer) + left = np.random.randint(0 + buffer, w - crop_size - buffer) + else: + top = (h - crop_size) // 2 + left = (w - crop_size) // 2 + + # Ensure top-left corner is even for correct Bayer pattern alignment + if top % 2 != 0: + top -= 1 + if left % 2 != 0: + left -= 1 + + # Handle potential boundary issues after adjustment + top = max(0, top) + left = max(0, left) + if top + crop_size > h: + top = h - crop_size + if left + crop_size > w: + left = w - crop_size + + bottom = top + crop_size + right = left + crop_size + return (left, right, top, bottom) + +# def cfa_to_sparse(cfa_image, pattern='RGGB'): +# """ +# Converts a 2D CFA image (H, W) to a 3D sparse RGB image (H, W, 3). +# """ +# H, W = cfa_image.shape +# sparse_rgb = np.zeros((H, W, 3), dtype=cfa_image.dtype) + +# # +# if pattern == 'RGGB': +# # R +# sparse_rgb[0::2, 0::2, 0] = cfa_image[0::2, 0::2] +# # G (top-right) +# sparse_rgb[0::2, 1::2, 1] = cfa_image[0::2, 1::2] +# # G (bottom-left) +# sparse_rgb[1::2, 0::2, 1] = cfa_image[1::2, 0::2] +# # B +# sparse_rgb[1::2, 1::2, 2] = cfa_image[1::2, 1::2] +# elif pattern == 'GRBG': +# sparse_rgb[0::2, 0::2, 1] = cfa_image[0::2, 0::2] +# sparse_rgb[0::2, 1::2, 0] = cfa_image[0::2, 1::2] +# sparse_rgb[1::2, 0::2, 2] = cfa_image[1::2, 0::2] +# sparse_rgb[1::2, 1::2, 1] = cfa_image[1::2, 1::2] +# elif pattern == 'GBRG': +# sparse_rgb[0::2, 0::2, 1] = cfa_image[0::2, 0::2] +# sparse_rgb[0::2, 1::2, 2] = cfa_image[0::2, 1::2] +# sparse_rgb[1::2, 0::2, 0] = cfa_image[1::2, 0::2] +# sparse_rgb[1::2, 1::2, 1] = cfa_image[1::2, 1::2] +# elif pattern == 'BGGR': +# sparse_rgb[0::2, 0::2, 2] = cfa_image[0::2, 0::2] +# sparse_rgb[0::2, 1::2, 1] = cfa_image[0::2, 1::2] +# sparse_rgb[1::2, 0::2, 1] = cfa_image[1::2, 0::2] +# sparse_rgb[1::2, 1::2, 0] = cfa_image[1::2, 1::2] +# else: +# raise 
NotImplementedError(f"Pattern {pattern} not implemented") + +# return sparse_rgb + +def cfa_to_rggb_stack(cfa_image, pattern='RGGB'): + """ + Converts a (H, W) CFA image to an (4, H/2, W/2) RGGB stack. + """ + assert cfa_image.ndim == 2, "Input must be (H, W)" + H, W = cfa_image.shape + assert H % 2 == 0 and W % 2 == 0, "Height and width must be even" + + if pattern == 'RGGB': + R = cfa_image[0::2, 0::2] + G1 = cfa_image[0::2, 1::2] # G at top-right + G2 = cfa_image[1::2, 0::2] # G at bottom-left + B = cfa_image[1::2, 1::2] + elif pattern == 'GRBG': + G1 = cfa_image[0::2, 0::2] + R = cfa_image[0::2, 1::2] + B = cfa_image[1::2, 0::2] + G2 = cfa_image[1::2, 1::2] + elif pattern == 'GBRG': + G1 = cfa_image[0::2, 0::2] + B = cfa_image[0::2, 1::2] + R = cfa_image[1::2, 0::2] + G2 = cfa_image[1::2, 1::2] + elif pattern == 'BGGR': + B = cfa_image[0::2, 0::2] + G1 = cfa_image[0::2, 1::2] + G2 = cfa_image[1::2, 0::2] + R = cfa_image[1::2, 1::2] + else: + raise NotImplementedError(f"Pattern {pattern} not implemented") + + # Stack R, G1, G2, B + rggb_stack = np.stack([R, G1, G2, B], axis=0) + return rggb_stack + +def pixel_unshuffle(x, r): + C, H, W = x.shape + x = ( + x.reshape(C, H // r, r, W // r, r) + .transpose(0, 2, 4, 1, 3) + .reshape(C * r**2, H // r, W // r) + ) + return x + +class DemosaicingDataset(Dataset): + """ + Dataset for learned demosaicing. + + Workflow: + 1. Load High-Res Noisy DNG. + 2. Crop a large patch, get its RGB representation (e.g., from RawHandler). + 3. Downsample this RGB patch (area) to create the Ground Truth image. + 4. Apply Bayer mosaicing to the Ground Truth to create the network input. + 5. Provide input in sparse (3-ch) and RGGB (4-ch) formats. 
+ """ + def __init__(self, path, csv, colorspace, + output_crop_size=256, + downsample_factor=2, + buffer=10, + validation=False, + bayer_pattern='RGGB'): + super().__init__() + self.df = pd.read_csv(csv) + self.path = Path(path) + self.output_crop_size = (output_crop_size // 2) * 2 # Ensure even + self.downsample_factor = downsample_factor + self.input_crop_size = self.output_crop_size * self.downsample_factor + self.buffer = buffer + self.coordinate_iso = 6400.0 # Normalization constant for ISO + self.validation = validation + self.dtype = np.float32 # Use float32 for tensors + self.colorspace = colorspace + self.bayer_pattern = bayer_pattern + + def __len__(self): + return len(self.df) + + def __getitem__(self, idx): + row = self.df.iloc[idx] + + try: + name = Path(f"{row.bayer_path}").name + name = str(self.path / name.replace('_bayer.jpg', '.dng')) + noisy_rh = RawHandler(name) + except Exception as e: + print(f"Error loading {name}: {e}") + return self.__getitem__((idx + 1) % len(self)) # Skip bad file + + try: + dims = random_crop_dim( + noisy_rh.raw.shape, + self.input_crop_size, + self.buffer, + validation=self.validation + ) + + # Get (3, H_in, W_in) RGB patch from RawHandler + noisy_rgb_full = noisy_rh.as_rgb( + dims=dims, + colorspace=self.colorspace + ).astype(self.dtype) + + # 3. Downsample to create Ground Truth + # Convert to (1, 3, H_in, W_in) tensor for interpolation + noisy_tensor = torch.from_numpy(noisy_rgb_full).unsqueeze(0) + + gt_tensor = F.interpolate( + noisy_tensor, + scale_factor=1.0/self.downsample_factor, + mode='area', + recompute_scale_factor=False # Use scale_factor directly + ) + + # Back to (3, H_out, W_out) and clip + gt_tensor = gt_tensor.squeeze(0).clip(0, 1) + + # 4. 
Apply Bayer mosaicing to create network input + # Convert GT to (H_out, W_out, 3) numpy array + gt_numpy = gt_tensor.permute(1, 2, 0).numpy() + + # Create (H_out, W_out) 2D CFA + input_cfa_hw = mosaicing_CFA_Bayer( + gt_numpy, + pattern=self.bayer_pattern + ) + + # 5. Provide input in sparse and RGGB formats + + # Create (H_out, W_out, 3) sparse array + input_sparse_chw = cfa_to_sparse( + input_cfa_hw, + pattern=self.bayer_pattern + )[0] + # Convert to (3, H_out, W_out) tensor + input_sparse_chw = torch.from_numpy( + input_sparse_chw + ).to(torch.float32) + + + # Create (4, H_out/2, W_out/2) RGGB stack + input_cfa_hw_expanded = np.expand_dims(input_cfa_hw, 0) + input_rggb_stack = pixel_unshuffle(input_cfa_hw_expanded, 2) + # input_rggb_stack = cfa_to_rggb_stack( + # input_cfa_hw, + # pattern=self.bayer_pattern + # ) + # Convert to (4, H_out/2, W_out/2) tensor + input_rggb_tensor = torch.from_numpy(input_rggb_stack).to(torch.float32) + + # 6. Conditioning tensor + conditioning = torch.tensor( + [row.iso / self.coordinate_iso] + ).to(torch.float32) + + output = { + "ground_truth": gt_tensor, # (3, H_out, W_out) + "cfa_sparse": input_sparse_chw, # (3, H_out, W_out) + "cfa_rggb": input_rggb_tensor, # (4, H_out/2, W_out/2) + "conditioning": conditioning # (1,) + } + return output + + except Exception as e: + print(f"Error processing {name} (idx {idx}): {e}") + return self.__getitem__((idx + 1) % len(self)) # Skip bad file From ca48e44758d2fde97c1e877627770ffce7134db1 Mon Sep 17 00:00:00 2001 From: Ryan Mueller Date: Fri, 31 Oct 2025 21:53:54 -0400 Subject: [PATCH 50/56] Removed comments --- src/training/DemosaicingDataset.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/training/DemosaicingDataset.py b/src/training/DemosaicingDataset.py index 3c78971..dfed023 100644 --- a/src/training/DemosaicingDataset.py +++ b/src/training/DemosaicingDataset.py @@ -225,10 +225,7 @@ def __getitem__(self, idx): # Create (4, H_out/2, W_out/2) RGGB stack 
input_cfa_hw_expanded = np.expand_dims(input_cfa_hw, 0) input_rggb_stack = pixel_unshuffle(input_cfa_hw_expanded, 2) - # input_rggb_stack = cfa_to_rggb_stack( - # input_cfa_hw, - # pattern=self.bayer_pattern - # ) + # Convert to (4, H_out/2, W_out/2) tensor input_rggb_tensor = torch.from_numpy(input_rggb_stack).to(torch.float32) From 430d9f2adbcbc1dc344826733c1117922cd5008a Mon Sep 17 00:00:00 2001 From: Ryan Mueller Date: Sat, 1 Nov 2025 15:38:55 -0400 Subject: [PATCH 51/56] Basic demosaicing code --- src/Restorer/Cond_NAF_demosaic.py | 939 ++++++++++++++++++++++++++++++ src/Restorer/debayering.py | 381 ++++++++++++ 2 files changed, 1320 insertions(+) create mode 100644 src/Restorer/Cond_NAF_demosaic.py create mode 100644 src/Restorer/debayering.py diff --git a/src/Restorer/Cond_NAF_demosaic.py b/src/Restorer/Cond_NAF_demosaic.py new file mode 100644 index 0000000..ac2089d --- /dev/null +++ b/src/Restorer/Cond_NAF_demosaic.py @@ -0,0 +1,939 @@ +import torch.nn.functional as F +import torch +import torch.nn as nn +from src.Restorer.debayering import Debayer5x5 + + +class LayerNorm2dAdjusted(nn.Module): + def __init__(self, channels, eps=1e-6): + super(LayerNorm2d, self).__init__() + self.register_parameter("weight", nn.Parameter(torch.ones(channels))) + self.register_parameter("bias", nn.Parameter(torch.zeros(channels))) + self.eps = eps + + def forward(self, x, target_mu, target_var): + mu = x.mean(1, keepdim=True) + var = (x - mu).pow(2).mean(1, keepdim=True) + + y = (x - mu) / torch.sqrt(var + self.eps) + + y = y * torch.sqrt(target_var + self.eps) + target_mu + + weight_view = self.weight.view(1, self.weight.size(0), 1, 1) + bias_view = self.bias.view(1, self.bias.size(0), 1, 1) + + y = weight_view * y + bias_view + return y + +class LayerNorm2d(nn.Module): + def __init__(self, channels, eps=1e-6): + super(LayerNorm2d, self).__init__() + self.register_parameter("weight", nn.Parameter(torch.ones(channels))) + self.register_parameter("bias", 
nn.Parameter(torch.zeros(channels))) + self.eps = eps + + def forward(self, x): + mu = x.mean(1, keepdim=True) + var = (x - mu).pow(2).mean(1, keepdim=True) + + y = (x - mu) / torch.sqrt(var + self.eps) + + weight_view = self.weight.view(1, self.weight.size(0), 1, 1) + bias_view = self.bias.view(1, self.bias.size(0), 1, 1) + + y = weight_view * y + bias_view + return y + +class SimpleGate(nn.Module): + def forward(self, x): + x1, x2 = x.chunk(2, dim=1) + return x1 * x2 + + +class ConditionedChannelAttention(nn.Module): + def __init__(self, dims, cat_dims): + super().__init__() + in_dim = dims + cat_dims + self.mlp = nn.Sequential(nn.Linear(in_dim, dims)) + self.pool = nn.AdaptiveAvgPool2d(1) + + def forward(self, x, conditioning): + pool = self.pool(x) + conditioning = conditioning.unsqueeze(-1).unsqueeze(-1) + cat_channels = torch.cat([pool, conditioning], dim=1) + cat_channels = cat_channels.permute(0, 2, 3, 1) + ca = self.mlp(cat_channels).permute(0, 3, 1, 2) + + return ca + +class CondFuser(nn.Module): + def __init__(self, chan, cond_chan=1): + super().__init__() + self.cca = ConditionedChannelAttention(chan * 2, cond_chan) + # self.spa = nn.Conv2d( + # in_channels=chan * 2, + # out_channels=1, + # kernel_size=3, + # padding=1, + # stride=1, + # groups=1, + # bias=True, + # ) + + def forward(self, x1, x2, cond): + x = torch.cat([x1, x2], dim=1) + x = self.cca(x, cond) * x + # spa = torch.sigmoid(self.spa(x)) + + x1, x2 = x.chunk(2, dim=1) + # return x1 * spa + x2 * (1 - spa) + return x1 + x2 + + +class NKA(nn.Module): + def __init__(self, dim, channel_reduction = 8): + super().__init__() + + reduced_channels = dim // channel_reduction + self.proj_1 = nn.Conv2d(dim, reduced_channels, 1, 1, 0) + self.dwconv = nn.Conv2d(reduced_channels, reduced_channels, 3, 1, 1, groups=reduced_channels) + self.proj_2 = nn.Conv2d(reduced_channels, reduced_channels * 2, 1, 1, 0) + self.sg = SimpleGate() + self.attention = nn.Conv2d(reduced_channels, dim, 1, 1, 0) + + def 
forward(self, x): + B, C, H, W = x.shape + # First projection to a smaller dimension + y = self.proj_1(x) + # DW conv + attn = self.dwconv(y) + # PW to increase channel count for SG + attn = self.proj_2(attn) + # Non-linearity + attn = self.sg(attn) + # Back to original dimensions + out = x * self.attention(attn) + return out + + +class CondFuser(nn.Module): + def __init__(self, chan, cond_chan=1): + super().__init__() + self.cca = ConditionedChannelAttention(chan * 2, cond_chan) + + def forward(self, x1, x2, cond): + x = torch.cat([x1, x2], dim=1) + x = self.cca(x, cond) * x + x1, x2 = x.chunk(2, dim=1) + return x1 + x2 + +class CondFuserAdd(nn.Module): + def __init__(self, chan, cond_chan=1): + super().__init__() + + def forward(self, x1, x2, cond): + return x1 + x2 + +class CondFuserV2(nn.Module): + def __init__(self, chan, cond_chan=1): + super().__init__() + self.cca = ConditionedChannelAttention(chan * 2, cond_chan) + self.spa = NKA(chan * 2) + + def forward(self, x1, x2, cond): + x = torch.cat([x1, x2], dim=1) + x = self.cca(x, cond) * x + spa = torch.sigmoid(self.spa(x)) * x + x1, x2 = x.chunk(2, dim=1) + return x1 * x2 + + +class CondFuserV3(nn.Module): + def __init__(self, chan, cond_chan=1): + super().__init__() + self.cca = ConditionedChannelAttention(chan * 2, cond_chan) + self.spa = nn.Conv2d( + in_channels=chan * 2, + out_channels=1, + kernel_size=3, + padding=1, + stride=1, + groups=1, + bias=True, + ) + + def forward(self, x1, x2, cond): + x = torch.cat([x1, x2], dim=1) + x = self.cca(x, cond) * x + spa = torch.sigmoid(self.spa(x)) + + x1, x2 = x.chunk(2, dim=1) + return x1 * spa + x2 * (1 - spa) + +class CondFuserV4(nn.Module): + def __init__(self, chan, cond_chan=1): + super().__init__() + self.cca = ConditionedChannelAttention(chan * 2, cond_chan) + self.pw = nn.Conv2d(chan * 2, chan, 1, stride=1, padding=0, groups=1) + def forward(self, x1, x2, cond): + x = torch.cat([x1, x2], dim=1) + x = self.cca(x, cond) * x + x = self.pw(x) + return x + + 
+class NAFBlock0(nn.Module): + def __init__(self, c, DW_Expand=2, FFN_Expand=2, drop_out_rate=0.0, cond_chans=0): + super().__init__() + dw_channel = c * DW_Expand + self.conv1 = nn.Conv2d( + in_channels=c, + out_channels=dw_channel, + kernel_size=1, + padding=0, + stride=1, + groups=1, + bias=True, + ) + self.conv2 = nn.Conv2d( + in_channels=dw_channel, + out_channels=dw_channel, + kernel_size=3, + padding=1, + stride=1, + groups=dw_channel, + bias=True, + ) + self.conv3 = nn.Conv2d( + in_channels=dw_channel // 2, + out_channels=c, + kernel_size=1, + padding=0, + stride=1, + groups=1, + bias=True, + ) + + # Simplified Channel Attention + self.sca = ConditionedChannelAttention(dw_channel // 2, cond_chans) + + # SimpleGate + self.sg = SimpleGate() + + ffn_channel = FFN_Expand * c + self.conv4 = nn.Conv2d( + in_channels=c, + out_channels=ffn_channel, + kernel_size=1, + padding=0, + stride=1, + groups=1, + bias=True, + ) + self.conv5 = nn.Conv2d( + in_channels=ffn_channel // 2, + out_channels=c, + kernel_size=1, + padding=0, + stride=1, + groups=1, + bias=True, + ) + + # self.grn = GRN(ffn_channel // 2) + + self.norm1 = LayerNorm2d(c) + self.norm2 = LayerNorm2d(c) + + self.dropout1 = ( + nn.Dropout(drop_out_rate) if drop_out_rate > 0.0 else nn.Identity() + ) + self.dropout2 = ( + nn.Dropout(drop_out_rate) if drop_out_rate > 0.0 else nn.Identity() + ) + + self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True) + self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True) + + def forward(self, input): + inp = input[0] + cond = input[1] + + x = inp + + x = self.norm1(x) + x = self.conv1(x) + x = self.conv2(x) + x = self.sg(x) + x = x * self.sca(x, cond) + x = self.conv3(x) + + x = self.dropout1(x) + + y = inp + x * self.beta + + # Channel Mixing + x = self.conv4(self.norm2(y)) + x = self.sg(x) + # x = self.grn(x) + x = self.conv5(x) + + x = self.dropout2(x) + + return (y + x * self.gamma, cond) + + +class NAFBlock0_learned_norm(nn.Module): + 
def __init__(self, c, DW_Expand=2, FFN_Expand=2, drop_out_rate=0.0, cond_chans=0): + super().__init__() + dw_channel = c * DW_Expand + self.conv1 = nn.Conv2d( + in_channels=c, + out_channels=dw_channel, + kernel_size=1, + padding=0, + stride=1, + groups=1, + bias=True, + ) + self.conv2 = nn.Conv2d( + in_channels=dw_channel, + out_channels=dw_channel, + kernel_size=3, + padding=1, + stride=1, + groups=dw_channel, + bias=True, + ) + self.conv3 = nn.Conv2d( + in_channels=dw_channel // 2, + out_channels=c, + kernel_size=1, + padding=0, + stride=1, + groups=1, + bias=True, + ) + + # Simplified Channel Attention + self.sca = ConditionedChannelAttention(dw_channel // 2, cond_chans) + + # SimpleGate + self.sg = SimpleGate() + + ffn_channel = FFN_Expand * c + self.conv4 = nn.Conv2d( + in_channels=c, + out_channels=ffn_channel, + kernel_size=1, + padding=0, + stride=1, + groups=1, + bias=True, + ) + self.conv5 = nn.Conv2d( + in_channels=ffn_channel // 2, + out_channels=c, + kernel_size=1, + padding=0, + stride=1, + groups=1, + bias=True, + ) + + # self.grn = GRN(ffn_channel // 2) + + self.norm1 = LayerNorm2d(c) + self.norm2 = LayerNorm2d(c) + + self.dropout1 = ( + nn.Dropout(drop_out_rate) if drop_out_rate > 0.0 else nn.Identity() + ) + self.dropout2 = ( + nn.Dropout(drop_out_rate) if drop_out_rate > 0.0 else nn.Identity() + ) + + self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True) + self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True) + self.sca_mul = ConditionedChannelAttention(c, cond_chans) + self.sca_add = ConditionedChannelAttention(c, cond_chans) + + def forward(self, input): + inp = input[0] + cond = input[1] + + x = inp + + x = self.norm1(x) + x = self.conv1(x) + x = self.conv2(x) + x = self.sg(x) + x = x * self.sca(x, cond) + x = self.conv3(x) + + x = self.dropout1(x) + + y = inp + x * self.beta + + # Channel Mixing + normed = self.norm2(y) + + # Input mediated channel attention, obstensibly to mitigate the effects of group 
norm on flat scenes + x = (1 + self.sca_mul(inp, cond)) * normed + self.sca_add(inp, cond) + + x = self.conv4(x) + x = self.sg(x) + # x = self.grn(x) + x = self.conv5(x) + + x = self.dropout2(x) + + + + return (y + x * self.gamma, cond) + +class NAFBlock0AdjustedNorm(nn.Module): + def __init__(self, c, DW_Expand=2, FFN_Expand=2, drop_out_rate=0.0, cond_chans=0): + super().__init__() + dw_channel = c * DW_Expand + self.conv1 = nn.Conv2d( + in_channels=c, + out_channels=dw_channel, + kernel_size=1, + padding=0, + stride=1, + groups=1, + bias=True, + ) + self.conv2 = nn.Conv2d( + in_channels=dw_channel, + out_channels=dw_channel, + kernel_size=3, + padding=1, + stride=1, + groups=dw_channel, + bias=True, + ) + self.conv3 = nn.Conv2d( + in_channels=dw_channel // 2, + out_channels=c, + kernel_size=1, + padding=0, + stride=1, + groups=1, + bias=True, + ) + + # Simplified Channel Attention + self.sca = ConditionedChannelAttention(dw_channel // 2, cond_chans) + + # SimpleGate + self.sg = SimpleGate() + + ffn_channel = FFN_Expand * c + self.conv4 = nn.Conv2d( + in_channels=c, + out_channels=ffn_channel, + kernel_size=1, + padding=0, + stride=1, + groups=1, + bias=True, + ) + self.conv5 = nn.Conv2d( + in_channels=ffn_channel // 2, + out_channels=c, + kernel_size=1, + padding=0, + stride=1, + groups=1, + bias=True, + ) + + # self.grn = GRN(ffn_channel // 2) + + self.norm1 = LayerNorm2dAdjusted(c) + self.norm2 = LayerNorm2dAdjusted(c) + + self.dropout1 = ( + nn.Dropout(drop_out_rate) if drop_out_rate > 0.0 else nn.Identity() + ) + self.dropout2 = ( + nn.Dropout(drop_out_rate) if drop_out_rate > 0.0 else nn.Identity() + ) + + self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True) + self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True) + self.sca_mul = ConditionedChannelAttention(c, cond_chans) + self.sca_add = ConditionedChannelAttention(c, cond_chans) + + def forward(self, input): + inp = input[0] + cond = input[1] + + x = inp + + x = 
self.norm1(x, mu, var) + x = self.conv1(x) + x = self.conv2(x) + x = self.sg(x) + x = x * self.sca(x, cond) + x = self.conv3(x) + + x = self.dropout1(x) + + y = inp + x * self.beta + + # Channel Mixing + normed = self.norm2(y, mu, var) + + # Input mediated channel attention, obstensibly to mitigate the effects of group norm on flat scenes + # x = (1 + self.sca_mul(inp, cond)) * normed + self.sca_add(inp, cond) + + x = self.conv4(normed) + x = self.sg(x) + # x = self.grn(x) + x = self.conv5(x) + + x = self.dropout2(x) + + return (y + x * self.gamma, cond, mu, var) + + +import torch.nn.functional as F + +class SwiGLU(nn.Module): + + def __init__(self, input_dim, hidden_dim, dropout=0.1): + super().__init__() + self.w1 = nn.Conv2d(input_dim, hidden_dim, 1, 1, 0, 1) + self.w2 = nn.Conv2d(input_dim, hidden_dim, 1, 1, 0, 1) + self.w3 = nn.Conv2d(hidden_dim, input_dim, 1, 1, 0, 1) + + def forward(self, x): + gate = F.silu(self.w1(x)) + value = self.w2(x) + x = gate * value + + x = self.w3(x) + return x + +class AttnBlock(nn.Module): + def __init__(self, c, FFN_Expand=2, drop_out_rate=0.0, cond_chans=0): + super().__init__() + + self.dw = nn.Conv2d( + in_channels=c, + out_channels=c, + kernel_size=3, + padding=1, + stride=1, + groups=c, + bias=True, + ) + self.nka = NKA(c) + + self.sca = ConditionedChannelAttention(c, cond_chans) + + self.norm = nn.GroupNorm(1, c) + + self.swiglu = SwiGLU(c, int(c * FFN_Expand)) + self.alpha = nn.Parameter(torch.zeros(1, c, 1, 1)) + self.beta = nn.Parameter(torch.zeros(1, c, 1, 1)) + + + def forward(self, input): + inp = input[0] + cond = input[1] + + x = self.dw(inp) + x = self.nka(x) + x = self.sca(x, cond) * x + y = self.norm(inp + self.alpha * x ) + + + x = self.swiglu(y) + x = y + self.beta * x + return (x, cond) + + +class CondSEBlock(nn.Module): + def __init__(self, chan, reduction=16, cond_chan=1): + super().__init__() + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Sequential( + nn.Linear(chan + cond_chan, chan // 
reduction, bias=False), + nn.ReLU(inplace=True), + nn.Linear(chan // reduction, chan, bias=False), + nn.Sigmoid() + ) + + def forward(self, x, conditioning): + b, c, _, _ = x.size() + y = self.avg_pool(x).view(b, c) + y = torch.cat([y, conditioning], dim=1) + y = self.fc(y).view(b, c, 1, 1) + return x * y.expand_as(x) + + + +class ConditioningCNN(nn.Module): + def __init__(self, in_channels=4, num_logits=128): + """ + Args: + in_channels (int): Number of input channels (e.g., 3 for RGB). + num_logits (int): The desired size of the output 1D logit vector. + """ + super().__init__() + + self.encoder = nn.Sequential( + nn.Conv2d(in_channels, 32, kernel_size=3, stride=2, padding='same'), + nn.ReLU(inplace=True), + nn.Conv2d(32, 64, kernel_size=3, stride=2, padding='same'), + nn.ReLU(inplace=True), + nn.Conv2d(64, 128, kernel_size=3, stride=2, padding='same'), + nn.ReLU(inplace=True), + nn.Conv2d(128, 256, kernel_size=3, stride=2, padding='same'), + nn.ReLU(inplace=True) + ) + + self.logit_head = nn.Sequential( + nn.AdaptiveAvgPool2d((1, 1)), + nn.Flatten(), + nn.Linear(256, num_logits) + ) + def forward(self, x): + x = self.encoder(x) + x = self.logit_head(x) + return x + +class Restorer(nn.Module): + def __init__( + self, + in_channels=3, + out_channels=3, + middle_blk_num=1, + enc_blk_nums=[], + dec_blk_nums=[], + chans = [], + cond_input=1, + cond_output=32, + expand_dims=2, + drop_out_rate=0.0, + drop_out_rate_increment=0.0, + rggb = False, + use_CondFuserV2 = False, + use_add = False, + use_CondFuserV3 = False, + use_CondFuserV4 = False, + use_attnblock = False, + use_NAFBlock0_learned_norm=False, + use_cond_net = False, + cond_net_num = 32, + use_input_stats=False, + use_NAFBlock0AdjustedNorm=False, + ): + super().__init__() + if use_attnblock: + block = AttnBlock + elif use_NAFBlock0_learned_norm: + block = NAFBlock0_learned_norm + elif use_NAFBlock0AdjustedNorm: + block = NAFBlock0AdjustedNorm + else: + block = NAFBlock0 + + width = chans[0] + + 
self.expand_dims = expand_dims + self.conditioning_gen = nn.Sequential( + nn.Linear(cond_input, 64), nn.ReLU(), nn.Dropout(drop_out_rate), nn.Linear(64, cond_output), + ) + + + self.use_cond_net = use_cond_net + if use_cond_net: + self.cond_net = ConditioningCNN(in_channels=in_channels, num_logits=cond_net_num) + cond_output = cond_output + cond_net_num + + self.use_input_stats = use_input_stats + if use_input_stats: + cond_output = cond_output + in_channels * 2 + + self.rggb = rggb + if not rggb: + self.intro = nn.Conv2d( + in_channels=in_channels, + out_channels=width, + kernel_size=3, + padding=1, + stride=1, + groups=1, + bias=True, + ) + else: + self.intro = nn.Sequential( + + nn.Conv2d( + in_channels=in_channels, + out_channels=width * 2 ** 2, + kernel_size=3, + padding=1, + stride=1, + groups=1, + bias=True, + ), + nn.PixelShuffle(2) + ) + + nn.Conv2d( + in_channels=in_channels, + out_channels=width, + kernel_size=3, + padding=1, + stride=1, + groups=1, + bias=True, + ) + self.ending = nn.Conv2d( + in_channels=width, + out_channels=out_channels, + kernel_size=3, + padding=1, + stride=1, + groups=1, + bias=True, + ) + + self.encoders = nn.ModuleList() + self.decoders = nn.ModuleList() + self.middle_blks = nn.ModuleList() + self.ups = nn.ModuleList() + self.downs = nn.ModuleList() + self.merges = nn.ModuleList() + + + # for num in enc_blk_nums: + for i in range(len(enc_blk_nums)): + current_chan = chans[i] + next_chan = chans[i + 1] + num = enc_blk_nums[i] + self.encoders.append( + nn.Sequential( + *[ + block(current_chan, cond_chans=cond_output, drop_out_rate=drop_out_rate) + for _ in range(num) + ] + ) + ) + drop_out_rate += drop_out_rate_increment + self.downs.append(nn.Conv2d(current_chan, next_chan, 2, 2)) + + self.middle_blks = nn.Sequential( + *[ + block(next_chan, cond_chans=cond_output, drop_out_rate=drop_out_rate) + for _ in range(middle_blk_num) + ] + ) + + for i in range(len(dec_blk_nums)): + current_chan = chans[-i-1] + next_chan = chans[-i-2] + 
num = dec_blk_nums[i] + self.ups.append( + nn.Sequential( + nn.Conv2d(current_chan, next_chan * 2 ** 2, 1, bias=False), nn.PixelShuffle(2) + ) + ) + drop_out_rate -= drop_out_rate_increment + if use_CondFuserV2: + self.merges.append(CondFuserV2(next_chan, cond_chan=cond_output)) + elif use_add: + self.merges.append(CondFuserAdd(next_chan, cond_chan=cond_output)) + elif use_CondFuserV3: + self.merges.append(CondFuserV3(next_chan, cond_chan=cond_output)) + elif use_CondFuserV4: + self.merges.append(CondFuserV4(next_chan, cond_chan=cond_output)) + else: + self.merges.append(CondFuser(next_chan, cond_chan=cond_output)) + + self.decoders.append( + nn.Sequential( + *[ + block(next_chan, cond_chans=cond_output, drop_out_rate=drop_out_rate) + for _ in range(num) + ] + ) + ) + + + self.padder_size = 2 ** len(self.encoders) + + self.alpha = nn.Parameter(torch.zeros(1, 1, 1, 1)) + + def forward(self, inp, cond_in): + # Conditioning: + cond = self.conditioning_gen(cond_in) + + # if self.use_cond_net: + # extra_cond = self.cond_net(inp) + # cond = torch.cat([cond, extra_cond], dim=1) + # if self.use_input_stats: + # mu = inp.mean((2,3), keepdim=True) + # var = (inp - mu).pow(2).mean((2,3), keepdim=False) + # mu = mu.squeeze(-1).squeeze(-1) + # cond = torch.cat([cond, mu, var], dim=1) + + B, C, H, W = inp.shape + if self.rggb: + H = 2 * H + W = 2 * W + inp = self.check_image_size(inp) + + x = self.intro(inp) + + encs = [] + for encoder, down in zip(self.encoders, self.downs): + x = encoder((x, cond))[0] + encs.append(x) + x = down(x) + + x = self.middle_blks((x, cond))[0] + + for decoder, up, merge, enc_skip in zip(self.decoders, self.ups, self.merges, encs[::-1]): + x = up(x) + x = merge(x, enc_skip, cond) + x = decoder((x, cond))[0] + + x = self.ending(x) * self.alpha + return x[:, :, :H, :W] + + def check_image_size(self, x): + _, _, h, w = x.size() + mod_pad_h = (self.padder_size - h % self.padder_size) % self.padder_size + mod_pad_w = (self.padder_size - w % 
self.padder_size) % self.padder_size + x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h)) + return x + +class ModelWrapper(nn.Module): + def __init__(self, **kwargs): + self.gamma = 1 + if 'gamma' in kwargs: + self.gamma = kwargs.pop('gamma') + super().__init__() + self.model = Restorer( + **kwargs + ) + + def forward(self, x, cond, residual): + x = x.clip(0, 1) ** (1. / self.gamma) + residual = residual.clip(0, 1) ** (1. / self.gamma) + output = self.model(x, cond) + output = (residual + output).clip(0, 1) ** (self.gamma) + return output + + +def make_full_model_RGGB(params, model_name=None): + model = ModelWrapper(**params) + if not model_name is None: + state_dict = torch.load(model_name, map_location="cpu") + model.load_state_dict(state_dict) + return model + + +class DemosaicingModelWrapper(nn.Module): + def __init__(self, **kwargs): + self.gamma = 1 + if 'gamma' in kwargs: + self.gamma = kwargs.pop('gamma') + super().__init__() + self.demosaicing = Debayer5x5() + self.model = Restorer( + **kwargs + ) + + + def forward(self, x, cond): + x = x.clip(0, 1) ** (1. 
/ self.gamma) + bayer = nn.functional.pixel_shuffle(x, 2) + debayered = self.demosaicing(bayer) + output = (self.model(x, cond) + debayered).clip(0,1) ** (self.gamma) + return output + + +def make_full_model_RGGB_Demosaicing(params, model_name=None): + + model = DemosaicingModelWrapper(**params) + if not model_name is None: + state_dict = torch.load(model_name, map_location="cpu") + model.load_state_dict(state_dict) + return model + + +class DemosaicingModelLW(nn.Module): + def __init__(self, **kwargs): + super().__init__() + self.demosaicing = Debayer5x5() + self.ps = nn.PixelShuffle(2) + self.us = nn.PixelUnshuffle(2) + self.block = NAFBlock0(16, cond_chans=1) + self.pw_out = nn.Conv2d(4, 3, 1) + + + def forward(self, rggb, cond): + + bayer = self.ps(rggb, 2) + debayered = self.demosaicing(bayer) + shuffled_bayer = self.us(debayered) + x = torch.cat([rggb, shuffled_bayer], dim=1) + x = self.block((x, cond)) + x = self.ps(x) + x = self.pw_out(x) + output = x + debayered + return output + + + +class DemosaicingModelWrapperSequential(nn.Module): + def __init__(self, **kwargs): + self.gamma = 1 + if 'gamma' in kwargs: + self.gamma = kwargs.pop('gamma') + kwargs['rggb'] = False + kwargs['in_channels'] = 3 + super().__init__() + self.demosaicing = Debayer5x5() + self.model = Restorer( + **kwargs + ) + + + def forward(self, x, cond): + x = x.clip(0, 1) ** (1. 
/ self.gamma) + bayer = nn.functional.pixel_shuffle(x, 2) + debayered = self.demosaicing(bayer) + output = (self.model(debayered, cond) + debayered).clip(0,1) ** (self.gamma) + return output + + +def make_full_model_RGGB_Demosaicing_Sequential(params, model_name=None): + + model = DemosaicingModelWrapperSequential(**params) + if not model_name is None: + state_dict = torch.load(model_name, map_location="cpu") + model.load_state_dict(state_dict) + return model + + + +class DemosaicingFromRGGB(nn.Module): + def __init__(self, **kwargs): + self.demosaicing = Debayer5x5() + self.model = Restorer( + **kwargs + ) + def forward(self, x, cond): + bayer = nn.functional.pixel_shuffle(x, 2) + debayered = self.demosaicing(bayer) + return debayered + \ No newline at end of file diff --git a/src/Restorer/debayering.py b/src/Restorer/debayering.py new file mode 100644 index 0000000..79b1bd0 --- /dev/null +++ b/src/Restorer/debayering.py @@ -0,0 +1,381 @@ +# Adapted from https://github.com/cheind/pytorch-debayer/blob/master/debayer/modules.py + +import torch +import torch.nn +import torch.nn.functional + +import enum + + +class Layout(enum.Enum): + """Possible Bayer color filter array layouts. + + The value of each entry is the color index (R=0,G=1,B=2) + within a 2x2 Bayer block. + """ + + RGGB = (0, 1, 1, 2) + GRBG = (1, 0, 2, 1) + GBRG = (1, 2, 0, 1) + BGGR = (2, 1, 1, 0) + +class Debayer3x3(torch.nn.Module): + """Demosaicing of Bayer images using 3x3 convolutions. + + Compared to Debayer2x2 this method does not use upsampling. + Instead, we identify five 3x3 interpolation kernels that + are sufficient to reconstruct every color channel at every + pixel location. + + We convolve the image with these 5 kernels using stride=1 + and a one pixel reflection padding. Finally, we gather + the correct channel values for each pixel location. Todo so, + we recognize that the Bayer pattern repeats horizontally and + vertically every 2 pixels. 
Therefore, we define the correct + index lookups for a 2x2 grid cell and then repeat to image + dimensions. + """ + + def __init__(self, layout: Layout = Layout.RGGB): + super(Debayer3x3, self).__init__() + self.layout = layout + # fmt: off + self.kernels = torch.nn.Parameter( + torch.tensor( + [ + [0, 0.25, 0], + [0.25, 0, 0.25], + [0, 0.25, 0], + + [0.25, 0, 0.25], + [0, 0, 0], + [0.25, 0, 0.25], + + [0, 0, 0], + [0.5, 0, 0.5], + [0, 0, 0], + + [0, 0.5, 0], + [0, 0, 0], + [0, 0.5, 0], + ] + ).view(4, 1, 3, 3), + requires_grad=False, + ) + # fmt: on + + self.index = torch.nn.Parameter( + self._index_from_layout(layout), + requires_grad=False, + ) + + def forward(self, x): + """Debayer image. + + Parameters + ---------- + x : Bx1xHxW tensor + Images to debayer + + Returns + ------- + rgb : Bx3xHxW tensor + Color images in RGB channel order. + """ + B, C, H, W = x.shape + + xpad = torch.nn.functional.pad(x, (1, 1, 1, 1), mode="reflect") + c = torch.nn.functional.conv2d(xpad, self.kernels, stride=1) + c = torch.cat((c, x), 1) # Concat with input to give identity kernel Bx5xHxW + + rgb = torch.gather( + c, + 1, + self.index.repeat( + 1, + 1, + torch.div(H, 2, rounding_mode="floor"), + torch.div(W, 2, rounding_mode="floor"), + ).expand( + B, -1, -1, -1 + ), # expand in batch is faster than repeat + ) + return rgb + + def _index_from_layout(self, layout: Layout) -> torch.Tensor: + """Returns a 1x3x2x2 index tensor for each color RGB in a 2x2 bayer tile. + + Note, the index corresponding to the identity kernel is 4, which will be + correct after concatenating the convolved output with the input image. + """ + # ... + # ... b g b g ... + # ... g R G r ... + # ... b G B g ... + # ... g r g r ... + # ... 
+ # fmt: off + rggb = torch.tensor( + [ + # dest channel r + [4, 2], # pixel is R,G1 + [3, 1], # pixel is G2,B + # dest channel g + [0, 4], # pixel is R,G1 + [4, 0], # pixel is G2,B + # dest channel b + [1, 3], # pixel is R,G1 + [2, 4], # pixel is G2,B + ] + ).view(1, 3, 2, 2) + # fmt: on + return { + Layout.RGGB: rggb, + Layout.GRBG: torch.roll(rggb, 1, -1), + Layout.GBRG: torch.roll(rggb, 1, -2), + Layout.BGGR: torch.roll(rggb, (1, 1), (-1, -2)), + }.get(layout) + + +class Debayer2x2(torch.nn.Module): + """Fast demosaicing of Bayer images using 2x2 convolutions. + + This method uses 3 kernels of size 2x2 and stride 2. Each kernel + corresponds to a single color RGB. For R and B the corresponding + value from each 2x2 Bayer block is taken according to the layout. + For G, G1 and G2 are averaged. The resulting image has half width/ + height and is upsampled by a factor of 2. + """ + + def __init__(self, layout: Layout = Layout.RGGB): + super(Debayer2x2, self).__init__() + self.layout = layout + + self.kernels = torch.nn.Parameter( + self._kernels_from_layout(layout), + requires_grad=False, + ) + + def forward(self, x): + """Debayer image. + + Parameters + ---------- + x : Bx1xHxW tensor + Images to debayer + + Returns + ------- + rgb : Bx3xHxW tensor + Color images in RGB channel order. + """ + x = torch.nn.functional.conv2d(x, self.kernels, stride=2) + + x = torch.nn.functional.interpolate( + x, scale_factor=2, mode="bilinear", align_corners=False + ) + return x + + def _kernels_from_layout(self, layout: Layout) -> torch.Tensor: + v = torch.tensor(layout.value).reshape(2, 2) + r = torch.zeros(2, 2) + r[v == 0] = 1.0 + + g = torch.zeros(2, 2) + g[v == 1] = 0.5 + + b = torch.zeros(2, 2) + b[v == 2] = 1.0 + + k = torch.stack((r, g, b), 0).unsqueeze(1) # 3x1x2x2 + return k + + +class DebayerSplit(torch.nn.Module): + """Demosaicing of Bayer images using 3x3 green convolution and red,blue upsampling. + Requires Bayer layout `Layout.RGGB`. 
+ """ + + def __init__(self, layout: Layout = Layout.RGGB): + super().__init__() + if layout != Layout.RGGB: + raise NotImplementedError("DebayerSplit only implemented for RGGB layout.") + self.layout = layout + + self.pad = torch.nn.ReflectionPad2d(1) + self.kernel = torch.nn.Parameter( + torch.tensor([[0, 1, 0], [1, 0, 1], [0, 1, 0]])[None, None] * 0.25 + ) + + def forward(self, x): + """Debayer image. + + Parameters + ---------- + x : Bx1xHxW tensor + Images to debayer + + Returns + ------- + rgb : Bx3xHxW tensor + Color images in RGB channel order. + """ + B, _, H, W = x.shape + red = x[:, :, ::2, ::2] + blue = x[:, :, 1::2, 1::2] + + green = torch.nn.functional.conv2d(self.pad(x), self.kernel) + green[:, :, ::2, 1::2] = x[:, :, ::2, 1::2] + green[:, :, 1::2, ::2] = x[:, :, 1::2, ::2] + + return torch.cat( + ( + torch.nn.functional.interpolate( + red, size=(H, W), mode="bilinear", align_corners=False + ), + green, + torch.nn.functional.interpolate( + blue, size=(H, W), mode="bilinear", align_corners=False + ), + ), + dim=1, + ) + + +class Debayer5x5(torch.nn.Module): + """Demosaicing of Bayer images using Malvar-He-Cutler algorithm. + + Requires BG-Bayer color filter array layout. That is, + the image[1,1]='B', image[1,2]='G'. This corresponds + to OpenCV naming conventions. + + Compared to Debayer2x2 this method does not use upsampling. + Compared to Debayer3x3 the algorithm gives sharper edges and + less chromatic effects. + + ## References + Malvar, Henrique S., Li-wei He, and Ross Cutler. + "High-quality linear interpolation for demosaicing of Bayer-patterned + color images."
2004 + """ + + def __init__(self, layout: Layout = Layout.RGGB): + super(Debayer5x5, self).__init__() + self.layout = layout + # fmt: off + self.kernels = torch.nn.Parameter( + torch.tensor( + [ + # G at R,B locations + # scaled by 16 + [ 0, 0, -2, 0, 0], # noqa + [ 0, 0, 4, 0, 0], # noqa + [-2, 4, 8, 4, -2], # noqa + [ 0, 0, 4, 0, 0], # noqa + [ 0, 0, -2, 0, 0], # noqa + + # R,B at G in R rows + # scaled by 16 + [ 0, 0, 1, 0, 0], # noqa + [ 0, -2, 0, -2, 0], # noqa + [-2, 8, 10, 8, -2], # noqa + [ 0, -2, 0, -2, 0], # noqa + [ 0, 0, 1, 0, 0], # noqa + + # R,B at G in B rows + # scaled by 16 + [ 0, 0, -2, 0, 0], # noqa + [ 0, -2, 8, -2, 0], # noqa + [ 1, 0, 10, 0, 1], # noqa + [ 0, -2, 8, -2, 0], # noqa + [ 0, 0, -2, 0, 0], # noqa + + # R at B and B at R + # scaled by 16 + [ 0, 0, -3, 0, 0], # noqa + [ 0, 4, 0, 4, 0], # noqa + [-3, 0, 12, 0, -3], # noqa + [ 0, 4, 0, 4, 0], # noqa + [ 0, 0, -3, 0, 0], # noqa + + # R at R, B at B, G at G + # identity kernel not shown + ] + ).view(4, 1, 5, 5).float() / 16.0, + requires_grad=False, + ) + # fmt: on + + self.index = torch.nn.Parameter( + # Below, note that index 4 corresponds to identity kernel + self._index_from_layout(layout), + requires_grad=False, + ) + + def forward(self, x): + """Debayer image. + + Parameters + ---------- + x : Bx1xHxW tensor + Images to debayer + + Returns + ------- + rgb : Bx3xHxW tensor + Color images in RGB channel order. 
+ """ + B, C, H, W = x.shape + + xpad = torch.nn.functional.pad(x, (2, 2, 2, 2), mode="reflect") + planes = torch.nn.functional.conv2d(xpad, self.kernels, stride=1) + planes = torch.cat( + (planes, x), 1 + ) # Concat with input to give identity kernel Bx5xHxW + rgb = torch.gather( + planes, + 1, + self.index.repeat( + 1, + 1, + torch.div(H, 2, rounding_mode="floor"), + torch.div(W, 2, rounding_mode="floor"), + ).expand( + B, -1, -1, -1 + ), # expand for singleton batch dimension is faster + ) + return torch.clamp(rgb, 0, 1) + + def _index_from_layout(self, layout: Layout) -> torch.Tensor: + """Returns a 1x3x2x2 index tensor for each color RGB in a 2x2 bayer tile. + + Note, the index corresponding to the identity kernel is 4, which will be + correct after concatenating the convolved output with the input image. + """ + # ... + # ... b g b g ... + # ... g R G r ... + # ... b G B g ... + # ... g r g r ... + # ... + # fmt: off + rggb = torch.tensor( + [ + # dest channel r + [4, 1], # pixel is R,G1 + [2, 3], # pixel is G2,B + # dest channel g + [0, 4], # pixel is R,G1 + [4, 0], # pixel is G2,B + # dest channel b + [3, 2], # pixel is R,G1 + [1, 4], # pixel is G2,B + ] + ).view(1, 3, 2, 2) + # fmt: on + return { + Layout.RGGB: rggb, + Layout.GRBG: torch.roll(rggb, 1, -1), + Layout.GBRG: torch.roll(rggb, 1, -2), + Layout.BGGR: torch.roll(rggb, (1, 1), (-1, -2)), + }.get(layout) \ No newline at end of file From 42aae649a2bde53d36afc9eb2da8acc919afe604 Mon Sep 17 00:00:00 2001 From: Ryan Mueller Date: Sat, 1 Nov 2025 15:39:38 -0400 Subject: [PATCH 52/56] Demosaicing training scripts --- 1_train_model_raw_demosaicing.ipynb | 22 ++++++- 1_validated_raw_demosaicing.ipynb | 96 +++++++++++++++++++++++++---- config_demosaicing.yaml | 24 ++++---- 3 files changed, 118 insertions(+), 24 deletions(-) diff --git a/1_train_model_raw_demosaicing.ipynb b/1_train_model_raw_demosaicing.ipynb index 6543f21..9c95d38 100644 --- a/1_train_model_raw_demosaicing.ipynb +++ 
b/1_train_model_raw_demosaicing.ipynb @@ -31,7 +31,7 @@ "from src.training.train_loop import train_one_epoch, visualize\n", "from src.training.utils import apply_gamma_torch\n", "from src.training.load_config import load_config\n", - "from src.Restorer.Cond_NAF import make_full_model_RGGB_Demosaicing\n" + "from src.Restorer.Cond_NAF_demosaic import make_full_model_RGGB_Demosaicing\n" ] }, { @@ -76,6 +76,16 @@ "params = {**run_config}" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "48e645a0", + "metadata": {}, + "outputs": [], + "source": [ + "mlflow_path" + ] + }, { "cell_type": "code", "execution_count": null, @@ -150,6 +160,16 @@ ")" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "58244dcc", + "metadata": {}, + "outputs": [], + "source": [ + "loss_fn = nn.L1Loss()" + ] + }, { "cell_type": "code", "execution_count": null, diff --git a/1_validated_raw_demosaicing.ipynb b/1_validated_raw_demosaicing.ipynb index b36bc93..3698218 100644 --- a/1_validated_raw_demosaicing.ipynb +++ b/1_validated_raw_demosaicing.ipynb @@ -84,7 +84,7 @@ "outputs": [], "source": [ "\n", - "RUN_ID = \"1e253a47ff7d43e1a432cce4ed083c7b\" \n", + "RUN_ID = \"01a8dc61c70c47f99e3b846f21c11fb7\" \n", "ARTIFACT_PATH = run_config['run_path']\n", "\n", "model_uri = f\"runs:/{RUN_ID}/{ARTIFACT_PATH}\"\n", @@ -99,6 +99,18 @@ " print(f\"Error loading model from MLflow: {e}\")" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "deaae86d", + "metadata": {}, + "outputs": [], + "source": [ + "from src.Restorer.debayering import Debayer5x5\n", + "debayer = Debayer5x5()\n", + "debayer = debayer.to(device)\n" + ] + }, { "cell_type": "code", "execution_count": null, @@ -186,7 +198,7 @@ "def visualize(idxs, _model, dataset, _device, _loss_func, rggb=False):\n", " import matplotlib.pyplot as plt\n", " _model.train()\n", - " total_loss, n_images, total_final_image_loss = 0.0, 0, 0.0\n", + " total_loss, total_loss_malhevar, n_images, total_final_image_loss = 
0.0, 0.0, 0, 0.0\n", " start = perf_counter()\n", "\n", " for idx in idxs:\n", @@ -194,17 +206,24 @@ " conditioning = row['conditioning'].unsqueeze(0).float().to(_device)\n", " gt = row['ground_truth'].unsqueeze(0).float().to(_device)\n", " sparse = row['cfa_sparse'].unsqueeze(0).float().to(_device)\n", + " rggb_values = row['cfa_rggb'].unsqueeze(0).float().to(_device)\n", + " bayer = nn.functional.pixel_shuffle(rggb_values, 2)\n", + "\n", + " debayered = debayer(bayer) \n", " input = sparse\n", " if rggb:\n", - " input = row['cfa_rggb'].unsqueeze(0).float().to(_device)\n", + " input = rggb_values\n", "\n", " conditioning = make_conditioning(conditioning, _device)\n", " \n", " with torch.no_grad():\n", " with torch.autocast(device_type=\"mps\", dtype=torch.bfloat16):\n", - " pred = _model(input, conditioning) \n", + " pred = _model(input, conditioning)\n", + " # pred = debayered \n", " loss = _loss_func(pred, gt)\n", + " loss_mal_he_var = _loss_func(debayered, gt)\n", "\n", + " total_loss_malhevar += float(loss_mal_he_var.detach().cpu())\n", " total_loss += float(loss.detach().cpu())\n", " n_images += gt.shape[0]\n", "\n", @@ -220,11 +239,14 @@ " plt.imshow(pred)\n", "\n", " plt.subplot(2, 3, 2)\n", - " plt.title('Menon')\n", - " bayer = sparse[0].sum(axis=0)\n", - " bayer = apply_gamma_torch(bayer).cpu().numpy()\n", - " trad_demosaiced = demosaicing_CFA_Bayer_Menon2007(bayer)\n", - " plt.imshow(trad_demosaiced)\n", + " # plt.title('Menon')\n", + " # bayer = sparse[0].sum(axis=0)\n", + " # bayer = apply_gamma_torch(bayer).cpu().numpy()\n", + " # trad_demosaiced = demosaicing_CFA_Bayer_Menon2007(bayer)\n", + " # plt.imshow(trad_demosaiced)\n", + " plt.title('Mal He Var')\n", + " debayered_numpy = apply_gamma_torch(debayered)[0].permute(1, 2, 0).cpu().numpy()\n", + " plt.imshow(debayered_numpy)\n", "\n", " plt.subplot(2, 3, 3)\n", " plt.title('gt')\n", @@ -237,16 +259,17 @@ "\n", "\n", " plt.subplot(2, 3, 5)\n", - " plt.imshow(trad_demosaiced - pred.cpu().numpy() + 
0.5)\n", + " plt.imshow(debayered_numpy - pred.cpu().numpy() + 0.5)\n", "\n", " plt.subplot(2, 3, 6)\n", - " plt.imshow(trad_demosaiced - gt.cpu().numpy() + 0.5)\n", + " plt.imshow(debayered_numpy - gt.cpu().numpy() + 0.5)\n", " plt.show()\n", " plt.clf()\n", "\n", " n_images = len(idxs)\n", " print(\n", " f\"Train loss: {total_loss/n_images:.6f} \"\n", + " f\"Train loss (MHV): {total_loss_malhevar/n_images:.6f} \"\n", " f\"Final image val loss: {total_final_image_loss/n_images:.6f} \"\n", " f\"Time: {perf_counter()-start:.1f}s \"\n", " f\"Images: {n_images}\")\n", @@ -277,6 +300,57 @@ "source": [ "visualize(matching_indices_in_subset, model, val_dataset, device, loss_fn, rggb=rggb)" ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f092eecc", + "metadata": {}, + "outputs": [], + "source": [ + "from src.Restorer.debayering import Debayer5x5" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "35fd5ed9", + "metadata": {}, + "outputs": [], + "source": [ + "output = dataset[0]\n", + "\n", + "rggb_values = output['cfa_rggb'].unsqueeze(0).float()\n", + "bayer = nn.functional.pixel_shuffle(rggb_values, 2).to('mps')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bfa7d2e9", + "metadata": {}, + "outputs": [], + "source": [ + "im = debayer(bayer)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "12e6a557", + "metadata": {}, + "outputs": [], + "source": [ + "plt.imshow(im[0].cpu().permute(1, 2, 0))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e90dd490", + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { diff --git a/config_demosaicing.yaml b/config_demosaicing.yaml index 6ecabe2..bb0f980 100644 --- a/config_demosaicing.yaml +++ b/config_demosaicing.yaml @@ -1,14 +1,14 @@ # config.yaml # --- Paths --- -base_data_dir: /Volumes/EasyStore/RAWNIND/ -jpeg_output_subdir: /Volumes/EasyStore/RAWNIND/JPEGs/Cropped_JPEG -cropped_jpeg_subdir: 
/Volumes/EasyStore/RAWNIND/JPEGs/Cropped_JPEG +# base_data_dir: /Volumes/EasyStore/RAWNIND/ +# jpeg_output_subdir: /Volumes/EasyStore/RAWNIND/JPEGs/Cropped_JPEG +# cropped_jpeg_subdir: /Volumes/EasyStore/RAWNIND/JPEGs/Cropped_JPEG cropped_raw_subdir: /Users/ryanmueller/Develop/Cropped_Raw/ cropped_raw_size: 2000 align_csv: align_data.csv secondary_align_csv: align_phase_v2.csv -script_path: /Volumes/EasyStore/models/traces -mlflow_path: /Volumes/EasyStore/models/mlfow +script_path: /Users/ryanmueller/Develop/traces/ +mlflow_path: /Users/ryanmueller/Develop/mlflow/ # --- Training Params --- colorspace: lin_rec2020 @@ -17,11 +17,11 @@ batch_size: 2 crop_size: 256 lr_base: 2.5e-5 clipping: 1e-2 -num_epochs_pretraining: 75 +num_epochs_pretraining: 5 num_epochs_finetuning: 20 val_split: 0.2 random_seed: 42 -cosine_annealing: True +cosine_annealing: False iso_range: [0, 999999] # --- Experiment Settings --- @@ -29,7 +29,7 @@ experiment: Demosaicing mlflow_experiment: Demosaicing # --- Run Configuration ---: -run_name: Demosaicing_test +run_name: Demosaicing_test_5x5_just_l1 run_path: Demosaicing_Test model_params: chans: [12, 24] @@ -50,10 +50,10 @@ model_params: gamma: 1.0 # --- Loss Configureation ---: -alpha: 0.2 +alpha: 1.0 beta: 5.0 -l1_weight: 0.16 -ssim_weight: 0.84 +l1_weight: 1 +ssim_weight: 0 tv_weight: 0.0 vgg_loss_weight: 0.0 -percept_loss_weight: 0.025 \ No newline at end of file +percept_loss_weight: 0.0 \ No newline at end of file From 1401b8e0f4f7330d2c69b1005dc8ce8c47774c67 Mon Sep 17 00:00:00 2001 From: Ryan Mueller Date: Tue, 4 Nov 2025 00:28:38 -0500 Subject: [PATCH 53/56] demosaicing models --- src/Restorer/Cond_NAF_demosaic.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/Restorer/Cond_NAF_demosaic.py b/src/Restorer/Cond_NAF_demosaic.py index ac2089d..fd872d4 100644 --- a/src/Restorer/Cond_NAF_demosaic.py +++ b/src/Restorer/Cond_NAF_demosaic.py @@ -927,11 +927,10 @@ def 
make_full_model_RGGB_Demosaicing_Sequential(params, model_name=None): class DemosaicingFromRGGB(nn.Module): - def __init__(self, **kwargs): + def __init__(self): + super().__init__() self.demosaicing = Debayer5x5() - self.model = Restorer( - **kwargs - ) + def forward(self, x, cond): bayer = nn.functional.pixel_shuffle(x, 2) debayered = self.demosaicing(bayer) From 2e738238de6b824437b8a184f48394a3f92cf049 Mon Sep 17 00:00:00 2001 From: Ryan Mueller Date: Tue, 4 Nov 2025 00:32:41 -0500 Subject: [PATCH 54/56] Built in demosaicing for residual --- src/Restorer/Cond_NAF.py | 26 ++++++++++++++++++-------- 1 file changed, 18 insertions(+), 8 deletions(-) diff --git a/src/Restorer/Cond_NAF.py b/src/Restorer/Cond_NAF.py index b43cda2..f710cf6 100644 --- a/src/Restorer/Cond_NAF.py +++ b/src/Restorer/Cond_NAF.py @@ -838,25 +838,35 @@ def make_full_model_RGGB(params, model_name=None): return model -class DemosaicingModelWrapper(nn.Module): + +from src.Restorer.Cond_NAF_demosaic import DemosaicingFromRGGB + +class ModelWrapperNoRes(nn.Module): def __init__(self, **kwargs): + super().__init__() + self.gamma = 1 if 'gamma' in kwargs: self.gamma = kwargs.pop('gamma') - super().__init__() + + self.demosaicer = DemosaicingFromRGGB() self.model = Restorer( **kwargs ) - def forward(self, x, cond): - x = x.clip(0, 1) ** (1. / self.gamma) - output = (self.model(x, cond)).clip(0,1) ** (self.gamma) + def forward(self, rggb, cond): + rggb = rggb.clip(0, 1) ** (1. / self.gamma) + debayered = self.demosaicer(rggb, cond) + debayered = debayered.clip(0, 1) ** (1. 
/ self.gamma) + output = self.model(rggb, cond) + output = (debayered + output).clip(0, 1) ** (self.gamma) return output -def make_full_model_RGGB_Demosaicing(params, model_name=None): - model = DemosaicingModelWrapper(**params) +def make_full_model_RGGB_NoRes(params, model_name=None): + model = ModelWrapperNoRes(**params) if not model_name is None: state_dict = torch.load(model_name, map_location="cpu") model.load_state_dict(state_dict) - return model \ No newline at end of file + return model + From 8715122162b5517897d2932da64adc654cbfb8a7 Mon Sep 17 00:00:00 2001 From: Ryan Mueller Date: Wed, 5 Nov 2025 08:50:54 -0500 Subject: [PATCH 55/56] Checkpoint config --- config.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/config.yaml b/config.yaml index 0467e0c..f85074d 100644 --- a/config.yaml +++ b/config.yaml @@ -6,7 +6,7 @@ cropped_jpeg_subdir: /Volumes/EasyStore/RAWNIND/JPEGs/Cropped_JPEG cropped_raw_subdir: /Users/ryanmueller/Develop/Cropped_Raw/ cropped_raw_size: 2000 align_csv: align_data.csv -secondary_align_csv: align_phase_v2.csv +secondary_align_csv: align_phase_v2_exposure_corr.csv script_path: /Volumes/EasyStore/models/traces mlflow_path: /Volumes/EasyStore/models/mlfow @@ -29,7 +29,7 @@ experiment: NAF_test mlflow_experiment: NAFNet_variations # --- Run Configuration ---: -run_name: NAF_gamma_test +run_name: NAF_corr_DIST run_path: NAF_deep_test_align model_params: chans: [32, 64, 128, 256, 256, 256] From 3cc2acb871567e873764c095bc74c7ebb4fbdf24 Mon Sep 17 00:00:00 2001 From: Ryan Mueller Date: Wed, 5 Nov 2025 12:29:43 -0500 Subject: [PATCH 56/56] Dataset version that loads raw handlers into memory --- .../RawDatasetDNG_load_into_memory.py | 156 ++++++++++++++++++ 1 file changed, 156 insertions(+) create mode 100644 src/training/RawDatasetDNG_load_into_memory.py diff --git a/src/training/RawDatasetDNG_load_into_memory.py b/src/training/RawDatasetDNG_load_into_memory.py new file mode 100644 index 0000000..bafd62f --- /dev/null 
+++ b/src/training/RawDatasetDNG_load_into_memory.py @@ -0,0 +1,156 @@ +import pandas as pd +import os +from torch.utils.data import Dataset +import imageio +from colour_demosaicing import ( + ROOT_RESOURCES_EXAMPLES, + demosaicing_CFA_Bayer_bilinear, + demosaicing_CFA_Bayer_Malvar2004, + demosaicing_CFA_Bayer_Menon2007, + mosaicing_CFA_Bayer) + +# from src.training.utils import inverse_gamma_tone_curve, cfa_to_sparse +import numpy as np +import torch +# from src.training.align_images import apply_alignment, align_clean_to_noisy +from pathlib import Path +from RawHandler.RawHandler import RawHandler + +class RawDatasetDNG(Dataset): + def __init__(self, path, csv, colorspace, crop_size=180, buffer=10, + validation=False, run_align=False, + dimensions=2000, + apply_exposure_corr=True, + demosaicing_func = demosaicing_CFA_Bayer_Malvar2004): + super().__init__() + self.df = pd.read_csv(csv) + self.path = path + self.crop_size = crop_size + self.buffer = buffer + self.coordinate_iso = 6400 + self.validation=validation + self.run_align = run_align + self.dtype = np.float16 + self.dimensions = dimensions + self.colorspace = colorspace + self.apply_exposure_corr = apply_exposure_corr + self.demosaicing_func = demosaicing_func + + files = os.listdir(path) + files = [f for f in files if 'dng' in f] + files = [f for f in files if not 'xmp' in f] + self.rhs = {} + for file in files: + self.rhs[file] = RawHandler(f'Cropped_Raw/{file}') + + + def __len__(self): + return len(self.df) + + def __getitem__(self, idx): + row = self.df.iloc[idx] + # Load images + name = Path(f"{row.bayer_path}").name + name = name.replace('_bayer.jpg', '.dng') + noisy_rh = self.rhs[name] + + gt_name = Path(f"{row.gt_path}").name + gt_name = gt_name.replace('.jpg', '.dng') + gt_rh = self.rhs[gt_name] + + dims = random_crop_dim(noisy_rh.raw.shape, self.crop_size, self.buffer, validation=self.validation) + + bayer_data = noisy_rh.apply_colorspace_transform(dims=dims, colorspace=self.colorspace) + noisy = 
noisy_rh.as_rgb(dims=dims, colorspace=self.colorspace, demosaicing_func=self.demosaicing_func, clip=False) + rggb = noisy_rh.as_rggb(dims=dims, colorspace=self.colorspace, clip=False) + sparse = noisy_rh.as_sparse(dims=dims, colorspace=self.colorspace, clip=False) + + check_align_matrix(row) + expanded_dims = [dims[0]-self.buffer, dims[1]+self.buffer, dims[2]-self.buffer, dims[3]+self.buffer] + gt_expanded = gt_rh.as_rgb(dims=expanded_dims, colorspace=self.colorspace, demosaicing_func=self.demosaicing_func, clip=False) + if self.apply_exposure_corr: + gt_expanded[0] *= row['r_scale_factor'] + gt_expanded[1] *= row['g_scale_factor'] + gt_expanded[2] *= row['b_scale_factor'] + aligned = apply_alignment(gt_expanded.transpose(1, 2, 0), row.to_dict())[self.buffer:-self.buffer, self.buffer:-self.buffer] + gt_non_aligned = gt_expanded.transpose(1, 2, 0)[self.buffer:-self.buffer, self.buffer:-self.buffer] + # # gt_non_aligned = gt_non_aligned * noisy.mean() / aligned.mean() + # # aligned = aligned * noisy.mean() / aligned.mean() + # # Get Raw data for testing + # noisy_raw = noisy_rh.raw[dims[0]:dims[1], dims[2]: dims[3]] + # row_dict = row.to_dict() + # shift_y, shift_x = round_to_nearest_2(row_dict['M12']), round_to_nearest_2(row_dict['M11']) + # gt_raw = gt_rh.raw[dims[0]+shift_y:dims[1]+shift_y, dims[2]+shift_x:dims[3]+shift_x] + # aligned = gt_rh.as_rgb(dims=dims, colorspace=self.colorspace).transpose(1, 2, 0) + + # Convert to tensors + output = { + "bayer": torch.tensor(bayer_data).to(float).clip(0,1), + "gt_non_aligned": torch.tensor(gt_non_aligned).to(float).permute(2, 0, 1).clip(0,1), + "aligned": torch.tensor(aligned).to(float).permute(2, 0, 1).clip(0,1), + "sparse": torch.tensor(sparse).to(float).clip(0,1), + "noisy": torch.tensor(noisy).to(float).clip(0,1), + "rggb": torch.tensor(rggb).to(float).clip(0,1), + "conditioning": torch.tensor([row.iso/self.coordinate_iso]).to(float), + # "noisy_raw": noisy_raw, + # "gt_raw": gt_raw, + # "noise_est": noise_est, + # 
"rggb_gt": rggb_gt, + } + return output + + + + +def global_affine_match(A, D, mask=None): + """ + Fit D ≈ a + b*A with least squares. + A, D : 2D arrays, same shape (linear values) + mask : optional boolean array, True=use pixel + returns: a, b, D_pred, D_resid (D - (a + b*A)) + """ + A = A.ravel().astype(np.float64) + D = D.ravel().astype(np.float64) + if mask is None: + mask = np.isfinite(A) & np.isfinite(D) + else: + mask = mask.ravel() & np.isfinite(A) & np.isfinite(D) + + A0 = A[mask] + D0 = D[mask] + # design matrix [1, A] + X = np.vstack([np.ones_like(A0), A0]).T + coef, *_ = np.linalg.lstsq(X, D0, rcond=None) + a, b = coef[0], coef[1] + D_pred = (a + b * A).reshape(-1) + D_pred = D_pred.reshape(A.shape) if False else (a + b * A).reshape((-1,)) # keep flatten + + return a, b, (a + b * A) + + +def random_crop_dim(shape, crop_size, buffer, validation=False): + h, w = shape + if not validation: + top = np.random.randint(0 + buffer, h - crop_size - buffer) + left = np.random.randint(0 + buffer, w - crop_size - buffer) + else: + top = (h - crop_size) // 2 + left = (w - crop_size) // 2 + + if top % 2 != 0: top = top - 1 + if left % 2 != 0: left = left - 1 + bottom = top + crop_size + right = left + crop_size + return (left, right, top, bottom) + +def check_align_matrix(row, tolerance=1e-7): + is_identity = np.isclose(row['M00'], 1.0, atol=tolerance) and \ + np.isclose(row['M01'], 0.0, atol=tolerance) and \ + np.isclose(row['M10'], 0.0, atol=tolerance) and \ + np.isclose(row['M11'], 1.0, atol=tolerance) + + assert is_identity, "Rotations, scalings, or shearing are not tested." + + +def round_to_nearest_2(number): + return round(number / 2) * 2 \ No newline at end of file