Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
393 changes: 393 additions & 0 deletions 1_compute_exposure.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,393 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "f6351e77",
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"from torch.utils.data import DataLoader, random_split\n",
"import copy\n",
"from pathlib import Path\n",
"import torch"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2043dc8e",
"metadata": {},
"outputs": [],
"source": [
"from src.training.RawDatasetDNG import RawDatasetDNG\n",
"from src.training.load_config import load_config\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "02b3a558",
"metadata": {},
"outputs": [],
"source": [
"from src.training.censored_fit import censored_linear_fit_twosided\n",
"from pytorch_msssim import ms_ssim\n",
"import torch.nn as nn\n",
"from tqdm import tqdm\n",
"import pandas as pd\n",
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b14af214",
"metadata": {},
"outputs": [],
"source": [
"run_config = load_config()\n",
"dataset_path = Path(run_config['cropped_raw_subdir'])\n",
"align_csv = dataset_path / run_config['secondary_align_csv']\n",
"new_align_csv = str(align_csv).replace('.csv', '_exposure_corr.csv')\n",
"device=run_config['device']\n",
"\n",
"batch_size = run_config['batch_size']\n",
"lr = run_config['lr_base'] * batch_size\n",
"clipping = run_config['clipping']\n",
"\n",
"num_epochs = run_config['num_epochs_pretraining']\n",
"val_split = run_config['val_split']\n",
"crop_size = run_config['crop_size']\n",
"experiment = run_config['mlflow_experiment']\n",
"model_params = run_config['model_params']\n",
"rggb = model_params['rggb']\n",
"\n",
"colorspace = run_config['colorspace']\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fcab6284",
"metadata": {},
"outputs": [],
"source": [
"crop_size = 1900"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "15f16fa7",
"metadata": {},
"outputs": [],
"source": [
"dataset = RawDatasetDNG(dataset_path, align_csv, colorspace, crop_size=crop_size, validation=True, apply_exposure_corr=False)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "55c0b09e",
"metadata": {},
"outputs": [],
"source": [
"def compute_scale_factors(aligned, noisy):\n",
"    \"\"\"Fit a censored linear relation between each aligned/noisy channel.\n",
"\n",
"    Runs the two-sided censored fit (no offset term) on channels 0..2 and\n",
"    returns (scale_factors, sigmas), one entry per channel.\n",
"    \"\"\"\n",
"    fits = [\n",
"        censored_linear_fit_twosided(aligned[ch], noisy[ch], include_offset=False)\n",
"        for ch in range(3)\n",
"    ]\n",
"    scale_factors = [b for _, b, _ in fits]\n",
"    sigmas = [s for _, _, s in fits]\n",
"    return scale_factors, sigmas"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e08dd7c4",
"metadata": {},
"outputs": [],
"source": [
"# Fit per-channel exposure gains for every image pair and record the L1\n",
"# error before and after applying them, so the correction can be verified.\n",
"row_list = []\n",
"for i in tqdm(range(len(dataset))):\n",
"    try:\n",
"        output = dataset[i]\n",
"        scale_factors, sigmas = compute_scale_factors(output['aligned'], output['noisy'])\n",
"        row = dataset.df.iloc[i].to_dict()\n",
"        row['r_scale_factor'] = scale_factors[0]\n",
"        row['g_scale_factor'] = scale_factors[1]\n",
"        row['b_scale_factor'] = scale_factors[2]\n",
"\n",
"        row['r_sigma'] = sigmas[0]\n",
"        row['g_sigma'] = sigmas[1]\n",
"        row['b_sigma'] = sigmas[2]\n",
"\n",
"        aligned, noisy = output['aligned'], output['noisy']\n",
"        uncorrected_l1 = nn.functional.l1_loss(aligned, noisy).item()\n",
"        row['uncorrected_l1'] = uncorrected_l1\n",
"        # NOTE: scales `aligned` in place, mutating the tensor held in `output`.\n",
"        aligned[0] *= scale_factors[0]\n",
"        aligned[1] *= scale_factors[1]\n",
"        aligned[2] *= scale_factors[2]\n",
"\n",
"        corrected_l1 = nn.functional.l1_loss(aligned, noisy).item()\n",
"        row['corrected_l1'] = corrected_l1\n",
"\n",
"        row_list.append(row)\n",
"    except Exception as e:\n",
"        # A bare `except:` also swallowed KeyboardInterrupt/SystemExit and hid\n",
"        # the cause; report the failing index together with the error instead.\n",
"        print(i, e)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "df486c40",
"metadata": {},
"outputs": [],
"source": [
"align_csv = str(align_csv).replace('.csv', '_exposure_corr.csv')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "19ec5ad5",
"metadata": {},
"outputs": [],
"source": [
"# Persist the fitted correction factors. index=False prevents the integer\n",
"# index from being written, which would otherwise come back as a spurious\n",
"# 'Unnamed: 0' column when the CSV is re-read below.\n",
"df = pd.DataFrame(row_list)\n",
"df.to_csv(new_align_csv, index=False)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "80d2ca3a",
"metadata": {},
"outputs": [],
"source": [
"df = pd.read_csv(new_align_csv)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "39be207b",
"metadata": {},
"outputs": [],
"source": [
"# Comparing the corrected l1 with uncorrected one to verify the difference has reduced\n",
"plt.scatter(df.uncorrected_l1, df.corrected_l1)\n",
"plt.plot((0, 0.06), (0, 0.06), color='red')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "116d7049",
"metadata": {},
"outputs": [],
"source": [
"df.sort_values('r_scale_factor')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "090dc759",
"metadata": {},
"outputs": [],
"source": [
"dataset = RawDatasetDNG(dataset_path, new_align_csv, colorspace, crop_size=crop_size, validation=True, apply_exposure_corr=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "eeccfcc6",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"mask = dataset.df.r_scale_factor.values > 0 \n",
"matching_indices_in_subset = np.nonzero(mask)[0]\n",
"matching_indices_in_subset"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e2f6e368",
"metadata": {},
"outputs": [],
"source": [
"import warnings\n",
"warnings.filterwarnings(\"ignore\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7ce736d8",
"metadata": {},
"outputs": [],
"source": [
"# Min r_scale_factor\n",
"\n",
"for idx in tqdm(matching_indices_in_subset):\n",
" row = dataset[idx]\n",
"\n",
" aligned = row['aligned'].permute(1, 2, 0)**(1/2.2)\n",
" noisy = row['noisy'].permute(1, 2, 0)**(1/2.2)\n",
"\n",
" plt.subplots(1, 3, figsize=(15, 5))\n",
"\n",
" plt.subplot(1, 3, 1)\n",
" plt.imshow(aligned)\n",
"\n",
" plt.subplot(1, 3, 2)\n",
" plt.imshow(noisy)\n",
" plt.title(idx)\n",
" plt.subplot(1, 3, 3)\n",
" plt.imshow(aligned-noisy+0.5)\n",
" plt.savefig(f'/Volumes/EasyStore/cropped_raw_comp/{idx}.jpeg')\n",
" plt.clf()\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bbf4dc3c",
"metadata": {},
"outputs": [],
"source": [
"# Inspect a single problematic pair; index via `idx` so the plot titles\n",
"# below always match the sample actually loaded.\n",
"idx = 79\n",
"row = dataset[idx]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cc0a759e",
"metadata": {},
"outputs": [],
"source": [
"plt.subplots(1, 3, figsize=(15, 5))\n",
"aligned = row['aligned'].permute(1, 2, 0)**(1/2.2)\n",
"noisy = row['noisy'].permute(1, 2, 0)**(1/2.2)\n",
"plt.subplot(1, 3, 1)\n",
"plt.imshow(aligned)\n",
"\n",
"plt.subplot(1, 3, 2)\n",
"plt.imshow(noisy)\n",
"plt.title(idx)\n",
"plt.subplot(1, 3, 3)\n",
"plt.imshow(aligned-noisy+0.5)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0f5a99a2",
"metadata": {},
"outputs": [],
"source": [
"def compute_scale_factors(aligned, noisy):\n",
"    \"\"\"Fit a censored affine relation between each aligned/noisy channel.\n",
"\n",
"    Same as the earlier version but with the offset term enabled; returns\n",
"    (scale_factors, offsets, sigmas), one entry per channel.\n",
"    \"\"\"\n",
"    fits = [\n",
"        censored_linear_fit_twosided(aligned[ch], noisy[ch], include_offset=True)\n",
"        for ch in range(3)\n",
"    ]\n",
"    offsets = [a for a, _, _ in fits]\n",
"    scale_factors = [b for _, b, _ in fits]\n",
"    sigmas = [s for _, _, s in fits]\n",
"    return scale_factors, offsets, sigmas"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "adf31c78",
"metadata": {},
"outputs": [],
"source": [
"\n",
"scale_factors, offsets, sigmas = compute_scale_factors(row['aligned'], row['noisy'])\n",
"\n",
"row['r_scale_factor'] = scale_factors[0]\n",
"row['g_scale_factor'] = scale_factors[1]\n",
"row['b_scale_factor'] = scale_factors[2]\n",
"\n",
"row['r_sigma'] = sigmas[0]\n",
"row['g_sigma'] = sigmas[1]\n",
"row['b_sigma'] = sigmas[2]\n",
"\n",
"aligned, noisy = row['aligned'], row['noisy']\n",
"uncorrected_l1 = nn.functional.l1_loss(aligned, noisy).item()\n",
"row['uncorrected_l1'] = uncorrected_l1\n",
"aligned[0] *= scale_factors[0]\n",
"aligned[0] += offsets[0]\n",
"aligned[1] *= scale_factors[1]\n",
"aligned[1] += offsets[1]\n",
"\n",
"aligned[2] *= scale_factors[2]\n",
"aligned[2] += offsets[2]\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c3e5494f",
"metadata": {},
"outputs": [],
"source": [
"plt.subplots(1, 3, figsize=(15, 5))\n",
"aligned = aligned.permute(1, 2, 0)**(1/2.2)\n",
"noisy = row['noisy'].permute(1, 2, 0)**(1/2.2)\n",
"plt.subplot(1, 3, 1)\n",
"plt.imshow(aligned)\n",
"\n",
"plt.subplot(1, 3, 2)\n",
"plt.imshow(noisy)\n",
"plt.title(idx)\n",
"plt.subplot(1, 3, 3)\n",
"plt.imshow(aligned-noisy+0.5)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6e356fe9",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "OnSight",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.11"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Loading
Loading