diff --git a/.gitignore b/.gitignore index 7b004e5..663e14a 100644 --- a/.gitignore +++ b/.gitignore @@ -191,4 +191,14 @@ cython_debug/ # exclude from AI features like autocomplete and code analysis. Recommended for sensitive data # refer to https://docs.cursor.com/context/ignore-files .cursorignore -.cursorindexingignore \ No newline at end of file +.cursorindexingignore + + +# custom ignores +mlruns/ +.DS_store +*.png +*.jpeg +*.csv +*.xmp +*.dng \ No newline at end of file diff --git a/0_align_images.ipynb b/0_align_images.ipynb new file mode 100644 index 0000000..125aeff --- /dev/null +++ b/0_align_images.ipynb @@ -0,0 +1,133 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "ac7bbf09", + "metadata": {}, + "outputs": [], + "source": [ + "from tqdm import tqdm\n", + "import matplotlib.pyplot as plt\n", + "import pandas as pd\n", + "import numpy as np\n", + "from pathlib import Path" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "da5ede59", + "metadata": {}, + "outputs": [], + "source": [ + "from src.training.align_images import AlignImages\n", + "\n", + "from src.training.load_config import load_config" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9d50e4c7", + "metadata": {}, + "outputs": [], + "source": [ + "run_config = load_config()\n", + "path = Path(run_config['jpeg_output_subdir'])\n", + "outpath = Path(run_config['base_data_dir']) / run_config['jpeg_output_subdir']\n", + "secondary_align_csv = outpath / run_config['secondary_align_csv']" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "19e7da21", + "metadata": {}, + "outputs": [], + "source": [ + "dataset = AlignImages(outpath, run_config['align_csv'])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c549b157", + "metadata": {}, + "outputs": [], + "source": [ + "metric_list = []\n", + "for i in tqdm(range(len(dataset))):\n", + " gt, noisy, aligned, metrics = dataset[i]\n", + " 
metric_list.append(metrics)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5acbb11e", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "df = pd.DataFrame(metric_list)\n", + "df.to_csv(secondary_align_csv)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6ec8ee48", + "metadata": {}, + "outputs": [], + "source": [ + "df = pd.read_csv(secondary_align_csv)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "192ded31", + "metadata": {}, + "outputs": [], + "source": [ + "(df.PSNR_after - df.PSNR_before).hist(bins = np.linspace(-1, 4, 100))\n", + "plt.ylabel('Images')\n", + "plt.xlabel('Delta PSNR')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c04b9532", + "metadata": {}, + "outputs": [], + "source": [ + "(df.SSIM_after - df.SSIM_before).hist(bins = np.linspace(-.01, .1, 100))\n", + "plt.ylabel('Images')\n", + "plt.xlabel('Delta SSIM')" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "OnSight", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.11" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/0_produce_small_dataset.ipynb b/0_produce_small_dataset.ipynb new file mode 100644 index 0000000..d1b07e9 --- /dev/null +++ b/0_produce_small_dataset.ipynb @@ -0,0 +1,670 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "23c4096e", + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "import numpy as np\n", + "import imageio\n", + "import cv2\n", + "import matplotlib.pyplot as plt\n", + "import os\n", + "from pathlib import Path\n", + "import re\n", + "from PIL import Image\n", + "import re\n", + "from collections import defaultdict\n", + "\n", 
+ "from colour_demosaicing import (\n", + " ROOT_RESOURCES_EXAMPLES,\n", + " demosaicing_CFA_Bayer_bilinear,\n", + " demosaicing_CFA_Bayer_Malvar2004,\n", + " demosaicing_CFA_Bayer_Menon2007,\n", + " mosaicing_CFA_Bayer)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5b6231f2", + "metadata": {}, + "outputs": [], + "source": [ + "from RawHandler.RawHandler import RawHandler\n", + "from RawHandler.utils import linear_to_srgb\n", + "from src.training.load_config import load_config\n", + "\n", + "def apply_gamma(x, gamma=2.2):\n", + " return x ** (1 / gamma)\n", + "\n", + "def reverse_gamma(x, gamma=2.2):\n", + " return x ** gamma" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d2e21592", + "metadata": {}, + "outputs": [], + "source": [ + "run_config = load_config()\n", + "raw_path = Path(run_config['base_data_dir'])\n", + "outpath = Path(run_config['jpeg_output_subdir'])\n", + "alignment_csv = outpath / run_config['align_csv']\n", + "outpath_cropped = run_config['cropped_jpeg_subdir']\n", + "colorspace = run_config['colorspace']\n", + "\n", + "file_list = os.listdir(raw_path)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "783efea9", + "metadata": {}, + "outputs": [], + "source": [ + "def pair_images_by_scene(file_list, min_iso=100):\n", + " \"\"\"\n", + " Given a list of RAW image file paths:\n", + " 1. Extract ISO from filenames\n", + " 2. Remove files with ISO < min_iso\n", + " 3. Group by scene name\n", + " 4. 
Pair each image with the lowest-ISO version of the scene\n", + "\n", + " Args:\n", + " file_list (list of str): Paths to RAW files\n", + " min_iso (int): Minimum ISO to keep (default=100)\n", + "\n", + " Returns:\n", + " dict: {scene_name: [(img_path, gt_path), ...]}\n", + " \"\"\"\n", + " iso_pattern = re.compile(r\"_ISO(\\d+)_\")\n", + " scene_pairs = {}\n", + "\n", + " # Step 1: Extract iso and scene\n", + " images = []\n", + " for path in file_list:\n", + " filename = os.path.basename(path)\n", + " match = iso_pattern.search(filename)\n", + " if not match:\n", + " continue # skip if no ISO\n", + " iso = int(match.group(1))\n", + " if iso < min_iso:\n", + " continue # filter out low ISOs\n", + "\n", + " # Extract scene name:\n", + " if \"_GT_\" in filename:\n", + " scene = filename.split(\"_GT_\")[0]\n", + " else:\n", + " # Scene = part before \"_ISO\"\n", + " scene = filename.split(\"_ISO\")[0]\n", + " if 'X-Trans' in filename:\n", + " continue\n", + "\n", + " images.append((scene, iso, path))\n", + "\n", + " # Step 2: Group by scene\n", + " grouped = defaultdict(list)\n", + " for scene, iso, path in images:\n", + " grouped[scene].append((iso, path))\n", + "\n", + " # Step 3: For each scene, pick lowest ISO as GT\n", + " for scene, iso_paths in grouped.items():\n", + " iso_paths.sort(key=lambda x: x[0]) # sort by ISO ascending\n", + " gt_iso, gt_path = iso_paths[0] # lowest ISO ≥ min_iso\n", + " pairs = [(path, gt_path) for iso, path in iso_paths if path != gt_path]\n", + " scene_pairs[scene] = pairs\n", + "\n", + " return scene_pairs\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e7f13812", + "metadata": {}, + "outputs": [], + "source": [ + "def get_initial_warp_matrix(img1_gray, img2_gray, num_features=2000):\n", + " \"\"\"\n", + " Finds an initial warp matrix using ORB feature matching.\n", + "\n", + " Args:\n", + " img1_gray (np.array): The first grayscale image (template).\n", + " img2_gray (np.array): The second grayscale image 
(to be warped).\n", + " num_features (int): The number of features for ORB to detect.\n", + "\n", + " Returns:\n", + " np.array: The 2x3 partial-affine warp matrix (rotation, uniform scale, translation), or the identity matrix if it fails.\n", + " \"\"\"\n", + " try:\n", + " # Initialize ORB detector\n", + " orb = cv2.ORB_create(nfeatures=num_features)\n", + "\n", + " # Find the keypoints and descriptors with ORB\n", + " keypoints1, descriptors1 = orb.detectAndCompute(img1_gray, None)\n", + " keypoints2, descriptors2 = orb.detectAndCompute(img2_gray, None)\n", + " \n", + " # Descriptors can be None if no keypoints are found\n", + " if descriptors1 is None or descriptors2 is None:\n", + " return np.eye(2, 3, dtype=np.float32)\n", + "\n", + " # Create BFMatcher object\n", + " # NORM_HAMMING is used for binary descriptors like ORB\n", + " bf = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=True)\n", + "\n", + " # Match descriptors\n", + " matches = bf.match(descriptors1, descriptors2)\n", + "\n", + " # Sort them in the order of their distance (best matches first)\n", + " matches = sorted(matches, key=lambda x: x.distance)\n", + "\n", + " # Keep only the best matches (at most 50)\n", + " num_good_matches = min(len(matches), 50)\n", + " if num_good_matches < 10: # Need at least ~6-10 points for a robust estimate\n", + " return np.eye(2, 3, dtype=np.float32)\n", + " \n", + " matches = matches[:num_good_matches]\n", + "\n", + " # Extract location of good matches\n", + " points1 = np.zeros((len(matches), 2), dtype=np.float32)\n", + " points2 = np.zeros((len(matches), 2), dtype=np.float32)\n", + "\n", + " for i, match in enumerate(matches):\n", + " points1[i, :] = keypoints1[match.queryIdx].pt\n", + " points2[i, :] = keypoints2[match.trainIdx].pt\n", + "\n", + " # Estimate a 4-DOF similarity transform (rotation, uniform scale, translation) using RANSAC\n", + " # NOTE: cv2.estimateAffinePartial2D also fits a uniform scale, so the result is not strictly Euclidean\n", + " warp_matrix, _ = cv2.estimateAffinePartial2D(points2, points1, method=cv2.RANSAC)\n", + " \n", + 
" # If estimation fails, it returns None\n", + " if warp_matrix is None:\n", + " return np.eye(2, 3, dtype=np.float32)\n", + "\n", + " return warp_matrix.astype(np.float32)\n", + "\n", + " except cv2.error as e:\n", + " print(f\"OpenCV error during feature matching: {e}\")\n", + " return np.eye(2, 3, dtype=np.float32)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7d9504cd", + "metadata": {}, + "outputs": [], + "source": [ + "def save_warp_dataframe(warp_matrix):\n", + " \"\"\"\n", + " Flatten a warp matrix into a dict keyed m00, m01, ... (usable as one DataFrame/CSV row).\n", + " warp_matrix: 2x3 or 3x3 numpy array\n", + " Returns the dict only; nothing is written to disk here and no metadata argument is taken.\n", + " \"\"\"\n", + " flat = warp_matrix.flatten()\n", + " cols = [f\"m{i}{j}\" for i in range(warp_matrix.shape[0]) for j in range(warp_matrix.shape[1])]\n", + " row = dict(zip(cols, flat))\n", + " return row" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "419b9857", + "metadata": {}, + "outputs": [], + "source": [ + "def get_align_hybrid(noisy_fname, gt_fname, path, downsample_factor=4):\n", + " \"\"\"\n", + " Hybrid function to align images using feature-based pre-alignment (coarse)\n", + " and ECC (fine).\n", + " \"\"\"\n", + " # 1. 
Load raw files\n", + " noisy_handler = RawHandler(f'{path}/{noisy_fname}', colorspace=colorspace)\n", + " gt_handler = RawHandler(f'{path}/{gt_fname}', colorspace=colorspace)\n", + "\n", + " noisy_bayer = noisy_handler.apply_colorspace_transform(colorspace='lin_rec2020', clip=True).astype(np.float32)\n", + " gt_bayer = gt_handler.apply_colorspace_transform(colorspace='lin_rec2020', clip=True).astype(np.float32)\n", + " noisy_bayer = apply_gamma(noisy_bayer)\n", + " gt_bayer = apply_gamma(gt_bayer)\n", + " \n", + " noisy_image = demosaicing_CFA_Bayer_Malvar2004(noisy_bayer)\n", + " gt_image = demosaicing_CFA_Bayer_Malvar2004(gt_bayer)\n", + " noisy_image = np.clip(noisy_image, 0, 1)\n", + " gt_image = np.clip(gt_image, 0, 1)\n", + "\n", + "\n", + " # Note: OpenCV expects BGR order, RawHandler might give RGB. Ensure consistency.\n", + " # Assuming BGR for cvtColor. If RGB, use cv2.COLOR_RGB2GRAY.\n", + " gt_image_uint8 = (gt_image * 255.0).clip(0, 255).astype(np.uint8)\n", + " noisy_image_uint8 = (noisy_image * 255.0).clip(0, 255).astype(np.uint8)\n", + "\n", + " # 3. Convert to grayscale using the faster uint8 versions\n", + " noisy_gray = cv2.cvtColor(noisy_image_uint8, cv2.COLOR_BGR2GRAY)\n", + " gt_gray = cv2.cvtColor(gt_image_uint8, cv2.COLOR_BGR2GRAY)\n", + " h, w = noisy_gray.shape\n", + "\n", + " # 4. --- NEW: Get initial warp matrix from feature matching ---\n", + " # We run this on the full-res grayscale images for better keypoint detection\n", + " warp_matrix = get_initial_warp_matrix(gt_gray, noisy_gray)\n", + "\n", + " # 5. 
Downsample for ECC refinement\n", + " if downsample_factor > 1:\n", + " # We need to scale the initial guess to the downsampled size\n", + " warp_matrix_scaled = warp_matrix.copy()\n", + " warp_matrix_scaled[0, 2] /= downsample_factor\n", + " warp_matrix_scaled[1, 2] /= downsample_factor\n", + "\n", + " noisy_gray_small = cv2.resize(noisy_gray, (w // downsample_factor, h // downsample_factor), interpolation=cv2.INTER_AREA)\n", + " gt_gray_small = cv2.resize(gt_gray, (w // downsample_factor, h // downsample_factor), interpolation=cv2.INTER_AREA)\n", + " else:\n", + " noisy_gray_small = noisy_gray\n", + " gt_gray_small = gt_gray\n", + " warp_matrix_scaled = warp_matrix\n", + "\n", + " # 6. ECC alignment (refinement) using the improved initial guess\n", + " warp_mode = cv2.MOTION_EUCLIDEAN\n", + " criteria = (cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, 200, 1e-5)\n", + " # We provide `warp_matrix_scaled` as the initial guess!\n", + " try:\n", + "\n", + " (cc, warp_matrix_final_scaled) = cv2.findTransformECC(gt_gray_small, noisy_gray_small, warp_matrix_scaled, warp_mode, criteria)\n", + " except cv2.error:\n", + " # If ECC fails, use the initial matrix from feature matching\n", + " cc = -1.0 # Indicate failure or that we used the fallback\n", + " warp_matrix_final_scaled = warp_matrix_scaled\n", + " \n", + " # 7. Scale the final matrix translation back to full resolution\n", + " \n", + " warp_matrix_final = warp_matrix_final_scaled.copy()\n", + " if downsample_factor > 1:\n", + " warp_matrix_final[0, 2] *= downsample_factor\n", + " warp_matrix_final[1, 2] *= downsample_factor\n", + "\n", + " # 8. Warp the original full-resolution FLOAT image\n", + " gt_aligned = cv2.warpAffine(gt_image, warp_matrix_final, (w, h), flags=cv2.INTER_LINEAR + cv2.WARP_INVERSE_MAP)\n", + "\n", + " # 9. 
Metadata / stats\n", + " try:\n", + " iso = float(re.findall('ISO([0-9]+)', noisy_fname)[0])\n", + " except (IndexError, ValueError):\n", + " iso = 0\n", + "\n", + " info_dict = {\n", + " \"cc\": cc,\n", + " \"noisy_image\": noisy_fname,\n", + " \"gt_image\": gt_fname,\n", + " \"gt_mean\": gt_image.mean(),\n", + " \"noisy_mean\": noisy_image.mean(),\n", + " \"noise_level\": (noisy_image-gt_image).std(axis=(0,1)),\n", + " **save_warp_dataframe(warp_matrix_final),\n", + " \"iso\": iso,\n", + " }\n", + "\n", + " return info_dict, noisy_image, gt_aligned, noisy_bayer[0], gt_image" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9af686c6", + "metadata": {}, + "outputs": [], + "source": [ + "pair_file_list = pair_images_by_scene(file_list)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9b1d38d1", + "metadata": {}, + "outputs": [], + "source": [ + "def as_8bit(x):\n", + " return (x * 255).astype(np.uint8)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ebb696eb", + "metadata": {}, + "outputs": [], + "source": [ + "# Test loop so we can visualize the alignment performance\n", + "\n", + "list = []\n", + "idx = 0\n", + "for key in pair_file_list.keys():\n", + " image_pairs = pair_file_list[key]\n", + " print(idx, idx/len(pair_file_list))\n", + " idx+=1\n", + " jdx = 0\n", + " for (noise, gt) in image_pairs:\n", + " output, noisy_image, gt_aligned, noisy_bayer, gt_image = get_align_hybrid(noise, gt, path, downsample_factor=1)\n", + " break\n", + " break\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "11269156", + "metadata": {}, + "outputs": [], + "source": [ + "plt.subplots(2, 3, figsize=(30, 20))\n", + "\n", + "plt.subplot(2,3,1)\n", + "plt.imshow(noisy_image[1000:1100, 3000:3100])\n", + "\n", + "plt.subplot(2,3,2)\n", + "plt.imshow(noisy_bayer[1000:1100, 3000:3100])\n", + "\n", + "plt.subplot(2,3, 3)\n", + "plt.imshow(gt_image[1000:1100, 3000:3100])\n", + "\n", + "\n", + 
"plt.subplot(2,3, 4)\n", + "plt.imshow(noisy_image[1000:1100, 3000:3100]-gt_image[1000:1100, 3000:3100]+0.5)\n", + "\n", + "\n", + "plt.subplot(2,3, 5)\n", + "plt.imshow(noisy_image[1000:1100, 3000:3100]-gt_aligned[1000:1100, 3000:3100]+0.5)\n", + "\n", + "plt.subplot(2,3, 6)\n", + "# plt.imshow(noisy_bayer[1000:1100, 3000:3100]-noisy_image[1000:1100, 3000:3100]+0.5)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d1e1066b", + "metadata": {}, + "outputs": [], + "source": [ + "# Align and save the jpegs\n", + "list = []\n", + "idx = 0\n", + "for key in pair_file_list.keys():\n", + " image_pairs = pair_file_list[key]\n", + " print(idx, idx/len(pair_file_list))\n", + " idx+=1\n", + " jdx = 0\n", + " for (noise, gt) in image_pairs:\n", + " try:\n", + " output, noisy_image, gt_aligned, noisy_bayer, gt_image = get_align_hybrid(noise, gt, path, downsample_factor=1)\n", + " list.append(output)\n", + " noisy_image\n", + " imageio.imwrite(f\"{outpath}/{noise}.jpg\", as_8bit(noisy_image), quality=100)\n", + " imageio.imwrite(f\"{outpath}/{noise}_bayer.jpg\", as_8bit(noisy_bayer), quality=100)\n", + " if jdx==0:\n", + " imageio.imwrite(f\"{outpath}/{gt}.jpg\", as_8bit(gt_image), quality=100)\n", + " jdx+=1\n", + " except:\n", + " print(f\"skipping {noise}\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "79bab144", + "metadata": {}, + "outputs": [], + "source": [ + "df = pd.DataFrame(list)\n", + "df.to_csv(alignment_csv)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "337da98d", + "metadata": {}, + "outputs": [], + "source": [ + "###\n", + "## The following code allows for the existing dataset to be further reduced by cropping into center squares\n", + "###" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ca7517af", + "metadata": {}, + "outputs": [], + "source": [ + "# Load saved dataset and visualize alignment" + ] + }, + { + "cell_type": "code", + "execution_count": null, + 
"id": "3452667b", + "metadata": {}, + "outputs": [], + "source": [ + "df = pd.read_csv(alignment_csv)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9eb194bc", + "metadata": {}, + "outputs": [], + "source": [ + "row = df.iloc[-2]\n", + "\n", + "# Get Row Matrix\n", + "shape=(2,3)\n", + "cols = [f\"m{i}{j}\" for i in range(shape[0]) for j in range(shape[1])]\n", + "flat = np.array([row.pop(c) for c in cols], dtype=np.float32)\n", + "warp_matrix = flat.reshape(shape)\n", + "warp_matrix\n", + "\n", + "\n", + "noisy_name = row.noisy_image\n", + "gt_name = row.gt_image\n", + "\n", + "with imageio.imopen(f\"{outpath}/{noisy_name}_bayer.jpg\", \"r\") as image_resource:\n", + " bayer_data = image_resource.read()\n", + "\n", + "with imageio.imopen(f\"{outpath}/{noisy_name}.jpg\", \"r\") as image_resource:\n", + " noisy = image_resource.read()\n", + "\n", + "\n", + "with imageio.imopen(f\"{outpath}/{gt_name}.jpg\", \"r\") as image_resource:\n", + " gt_image = image_resource.read()\n", + "\n", + "noisy = noisy/255\n", + "gt_image = gt_image/255\n", + "bayer_data = bayer_data/255\n", + "h, w, _ = noisy.shape\n", + "gt = cv2.warpAffine(gt_image, warp_matrix, (w, h), flags=cv2.INTER_LINEAR + cv2.WARP_INVERSE_MAP)\n", + "\n", + "warp_matrix\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "621a322b", + "metadata": {}, + "outputs": [], + "source": [ + "demosaiced = demosaicing_CFA_Bayer_Malvar2004(bayer_data)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4f8dbf1f", + "metadata": {}, + "outputs": [], + "source": [ + "plt.subplots(2, 3, figsize=(30, 20))\n", + "\n", + "plt.subplot(2,3,1)\n", + "plt.imshow(noisy[1000:1100, 3000:3100])\n", + "\n", + "plt.subplot(2,3,2)\n", + "plt.imshow(demosaiced[1000:1100, 3000:3100])\n", + "\n", + "plt.subplot(2,3, 3)\n", + "plt.imshow(gt[1000:1100, 3000:3100])\n", + "\n", + "\n", + "plt.subplot(2,3, 4)\n", + "plt.imshow(gt[1000:1100, 
3000:3100]-demosaiced[1000:1100, 3000:3100]+0.5)\n", + "\n", + "\n", + "plt.subplot(2,3, 5)\n", + "plt.imshow(noisy[1000:1100, 3000:3100]-demosaiced[1000:1100, 3000:3100]+0.5) \n", + "\n", + "plt.subplot(2,3, 6)\n", + "plt.imshow(noisy[1000:1100, 3000:3100]-gt_image[1000:1100, 3000:3100]+0.5)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a37b64b4", + "metadata": {}, + "outputs": [], + "source": [ + "##\n", + "## Crop to center to save even more space\n", + "##" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2e3c61f1", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "def crop_center_square(input_dir, output_dir, crop_size):\n", + " \"\"\"\n", + " Loops over image files in a directory, crops a center square, and saves them.\n", + "\n", + " Args:\n", + " input_dir (str): The path to the directory containing the original images.\n", + " output_dir (str): The path to the directory where cropped images will be saved.\n", + " crop_size (int): The width and height of the square to crop (must be an even number).\n", + " \"\"\"\n", + " # --- 1. Input Validation ---\n", + " if not os.path.isdir(input_dir):\n", + " print(f\"Error: Input directory not found at '{input_dir}'\")\n", + " return\n", + "\n", + " if crop_size % 2 != 0:\n", + " print(f\"Error: Crop size must be an even number. You provided {crop_size}.\")\n", + " return\n", + " \n", + " # --- 2. Create Output Directory ---\n", + " # Create the output directory if it doesn't already exist.\n", + " os.makedirs(output_dir, exist_ok=True)\n", + " print(f\"Output will be saved to: '{output_dir}'\")\n", + "\n", + " # --- 3. 
Process Images ---\n", + " # Get a list of all files in the input directory.\n", + " files = os.listdir(input_dir)\n", + "\n", + " processed_count = 0\n", + " for filename in files:\n", + " # Construct the full path for the input file.\n", + " input_path = os.path.join(input_dir, filename)\n", + "\n", + " # Process only files, not subdirectories.\n", + " if os.path.isfile(input_path):\n", + " try:\n", + " # Open the image using Pillow.\n", + " with Image.open(input_path) as img:\n", + " width, height = img.size\n", + "\n", + " # Check if the image is large enough to be cropped.\n", + " if width < crop_size or height < crop_size:\n", + " print(f\"Skipping '{filename}': smaller than crop size.\")\n", + " continue\n", + " \n", + " # Calculate the coordinates for the center crop.\n", + " left = (width - crop_size) // 2\n", + " top = (height - crop_size) // 2\n", + "\n", + " # Ensure the top-left corner is on an even pixel coordinate for Bayer alignment.\n", + " if left % 2 != 0:\n", + " left -= 1\n", + " if top % 2 != 0:\n", + " top -= 1\n", + " \n", + " # Calculate the bottom-right corner based on the adjusted top-left corner.\n", + " # Since crop_size is even, right and bottom will also be even.\n", + " right = left + crop_size\n", + " bottom = top + crop_size\n", + "\n", + " # Perform the crop. The box is a 4-tuple defining the left, upper, right, and lower pixel coordinate.\n", + " img_cropped = np.array(img.crop((left, top, right, bottom)))\n", + "\n", + " # Construct the full path for the output file.\n", + " output_path = os.path.join(output_dir, filename)\n", + " \n", + " # Save the cropped image.\n", + " # img_cropped.save(output_path)\n", + " imageio.imwrite(output_path,img_cropped, quality=95)\n", + " processed_count += 1\n", + "\n", + " except (IOError, OSError) as e:\n", + " # Handle cases where the file is not a valid image.\n", + " print(f\"Could not process '{filename}'. It might not be an image file. 
Error: {e}\")\n", + " except Exception as e:\n", + " print(f\"An unexpected error occurred with file '{filename}': {e}\")\n", + " \n", + " print(f\"\\nProcessing complete. Cropped {processed_count} images.\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f9e4cf92", + "metadata": {}, + "outputs": [], + "source": [ + "crop_center_square(outpath, outpath_cropped, crop_size=2000)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "OnSight", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.11" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/0_produce_small_dataset_raw.ipynb b/0_produce_small_dataset_raw.ipynb new file mode 100644 index 0000000..48e83c2 --- /dev/null +++ b/0_produce_small_dataset_raw.ipynb @@ -0,0 +1,629 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "23c4096e", + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "import numpy as np\n", + "import imageio\n", + "import cv2\n", + "import matplotlib.pyplot as plt\n", + "import os\n", + "from pathlib import Path\n", + "import re\n", + "from PIL import Image\n", + "import re\n", + "from collections import defaultdict\n", + "\n", + "from colour_demosaicing import (\n", + " ROOT_RESOURCES_EXAMPLES,\n", + " demosaicing_CFA_Bayer_bilinear,\n", + " demosaicing_CFA_Bayer_Malvar2004,\n", + " demosaicing_CFA_Bayer_Menon2007,\n", + " mosaicing_CFA_Bayer)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5b6231f2", + "metadata": {}, + "outputs": [], + "source": [ + "from RawHandler.RawHandler import RawHandler\n", + "from RawHandler.utils import linear_to_srgb\n", + "from src.training.load_config import load_config\n", + "\n", + "def 
apply_gamma(x, gamma=2.2):\n", + " return x ** (1 / gamma)\n", + "\n", + "def reverse_gamma(x, gamma=2.2):\n", + " return x ** gamma" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d2e21592", + "metadata": {}, + "outputs": [], + "source": [ + "run_config = load_config()\n", + "raw_path = Path(run_config['base_data_dir'])\n", + "outpath = Path(run_config['cropped_raw_subdir'])\n", + "alignment_csv = outpath / run_config['align_csv']\n", + "colorspace = run_config['colorspace']\n", + "crop_size = run_config['cropped_raw_size']\n", + "file_list = os.listdir(raw_path)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "783efea9", + "metadata": {}, + "outputs": [], + "source": [ + "def pair_images_by_scene(file_list, min_iso=100):\n", + " \"\"\"\n", + " Given a list of RAW image file paths:\n", + " 1. Extract ISO from filenames\n", + " 2. Remove files with ISO < min_iso\n", + " 3. Group by scene name\n", + " 4. Pair each image with the lowest-ISO version of the scene\n", + "\n", + " Args:\n", + " file_list (list of str): Paths to RAW files\n", + " min_iso (int): Minimum ISO to keep (default=100)\n", + "\n", + " Returns:\n", + " dict: {scene_name: [(img_path, gt_path), ...]}\n", + " \"\"\"\n", + " iso_pattern = re.compile(r\"_ISO(\\d+)_\")\n", + " scene_pairs = {}\n", + "\n", + " # Step 1: Extract iso and scene\n", + " images = []\n", + " for path in file_list:\n", + " filename = os.path.basename(path)\n", + " match = iso_pattern.search(filename)\n", + " if not match:\n", + " continue # skip if no ISO\n", + " iso = int(match.group(1))\n", + " if iso < min_iso:\n", + " continue # filter out low ISOs\n", + "\n", + " # Extract scene name:\n", + " if \"_GT_\" in filename:\n", + " scene = filename.split(\"_GT_\")[0]\n", + " else:\n", + " # Scene = part before \"_ISO\"\n", + " scene = filename.split(\"_ISO\")[0]\n", + " if 'X-Trans' in filename:\n", + " continue\n", + "\n", + " images.append((scene, iso, path))\n", + "\n", + " # 
Step 2: Group by scene\n", + " grouped = defaultdict(list)\n", + " for scene, iso, path in images:\n", + " grouped[scene].append((iso, path))\n", + "\n", + " # Step 3: For each scene, pick lowest ISO as GT\n", + " for scene, iso_paths in grouped.items():\n", + " iso_paths.sort(key=lambda x: x[0]) # sort by ISO ascending\n", + " gt_iso, gt_path = iso_paths[0] # lowest ISO ≥ min_iso\n", + " pairs = [(path, gt_path) for iso, path in iso_paths if path != gt_path]\n", + " scene_pairs[scene] = pairs\n", + "\n", + " return scene_pairs\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9af686c6", + "metadata": {}, + "outputs": [], + "source": [ + "pair_file_list = pair_images_by_scene(file_list)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "206ab100", + "metadata": {}, + "outputs": [], + "source": [ + "def get_file(impath, crop_size=crop_size):\n", + " rh = RawHandler(impath)\n", + " \n", + " width, height = rh.raw.shape\n", + "\n", + " # Check if the image is large enough to be cropped.\n", + " if width < crop_size or height < crop_size:\n", + " im = rh.apply_colorspace_transform(colorspace=colorspace)\n", + " else:\n", + " # Calculate the coordinates for the center crop.\n", + " left = (width - crop_size) // 2\n", + " top = (height - crop_size) // 2\n", + "\n", + " # Ensure the top-left corner is on an even pixel coordinate for Bayer alignment.\n", + " if left % 2 != 0:\n", + " left -= 1\n", + " if top % 2 != 0:\n", + " top -= 1\n", + " \n", + " # Calculate the bottom-right corner based on the adjusted top-left corner.\n", + " # Since crop_size is even, right and bottom will also be even.\n", + " right = left + crop_size\n", + " bottom = top + crop_size\n", + " return rh.raw[left:right, top:bottom], rh" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c29f0629", + "metadata": {}, + "outputs": [], + "source": [ + "from tqdm import tqdm" + ] + }, + { + "cell_type": "code", + "execution_count": null, 
+ "id": "a5b63be2", + "metadata": {}, + "outputs": [], + "source": [ + "from pidng.core import RAW2DNG, DNGTags, Tag\n", + "from pidng.defs import *\n", + "\n", + "def get_ratios(string, rh):\n", + " return [x.as_integer_ratio() for x in rh.full_metadata[string].values]\n", + "\n", + "\n", + "def rational_wb(rh, denominator=1000):\n", + " wb = np.array(rh.core_metadata.camera_white_balance)\n", + " numerator_matrix = np.round(wb * denominator).astype(int)\n", + " return [[num, denominator] for num in numerator_matrix]\n", + "def convert_ccm_to_rational(matrix_3x3, denominator=10000):\n", + "\n", + " numerator_matrix = np.round(matrix_3x3 * denominator).astype(int)\n", + " numerators_flat = numerator_matrix.flatten()\n", + " ccm_rational = [[num, denominator] for num in numerators_flat]\n", + " \n", + " return ccm_rational\n", + "\n", + "\n", + "def get_as_shot_neutral(rh, denominator=10000):\n", + "\n", + " cam_mul = rh.core_metadata.camera_white_balance\n", + " \n", + " if cam_mul[0] == 0 or cam_mul[2] == 0:\n", + " return [[denominator, denominator], [denominator, denominator], [denominator, denominator]]\n", + "\n", + " r_neutral = cam_mul[1] / cam_mul[0]\n", + " g_neutral = 1.0 \n", + " b_neutral = cam_mul[1] / cam_mul[2]\n", + "\n", + " return [\n", + " [int(r_neutral * denominator), denominator],\n", + " [int(g_neutral * denominator), denominator],\n", + " [int(b_neutral * denominator), denominator],\n", + " ]\n", + "\n", + "\n", + "def to_dng(uint_img, rh, filepath):\n", + " width = uint_img.shape[1]\n", + " height = uint_img.shape[0]\n", + " bpp = 16\n", + "\n", + " ccm1 = convert_ccm_to_rational(rh.core_metadata.rgb_xyz_matrix[:3, :])\n", + " t = DNGTags()\n", + " t.set(Tag.ImageWidth, width)\n", + " t.set(Tag.ImageLength, height)\n", + " t.set(Tag.TileWidth, width)\n", + " t.set(Tag.TileLength, height)\n", + " t.set(Tag.BitsPerSample, bpp)\n", + "\n", + " t.set(Tag.SamplesPerPixel, 1) \n", + " t.set(Tag.PlanarConfiguration, 1) \n", + "\n", + " 
t.set(Tag.TileWidth, width)\n", + " t.set(Tag.TileLength, height)\n", + " t.set(Tag.PhotometricInterpretation, PhotometricInterpretation.Color_Filter_Array)\n", + " t.set(Tag.CFARepeatPatternDim, [2,2])\n", + " t.set(Tag.CFAPattern, CFAPattern.RGGB)\n", + " bl = rh.core_metadata.black_level_per_channel\n", + " t.set(Tag.BlackLevelRepeatDim, [2,2])\n", + " t.set(Tag.BlackLevel, bl)\n", + " t.set(Tag.WhiteLevel, rh.core_metadata.white_level)\n", + "\n", + " t.set(Tag.BitsPerSample, bpp)\n", + "\n", + " t.set(Tag.ColorMatrix1, ccm1)\n", + " t.set(Tag.CalibrationIlluminant1, CalibrationIlluminant.D65)\n", + " wb = get_as_shot_neutral(rh)\n", + " t.set(Tag.AsShotNeutral, wb)\n", + " t.set(Tag.BaselineExposure, [[0,100]])\n", + "\n", + "\n", + "\n", + " try:\n", + " t.set(Tag.Make, rh.full_metadata['Image Make'].values)\n", + " t.set(Tag.Model, rh.full_metadata['Image Model'].values)\n", + " exposures = get_ratios('EXIF ExposureTime', rh)\n", + " fnumber = get_ratios('EXIF FNumber', rh)\n", + " ExposureBiasValue = get_ratios('EXIF ExposureBiasValue', rh) \n", + " FocalLength = get_ratios('EXIF FocalLength', rh) \n", + " t.set(Tag.FocalLength, FocalLength)\n", + " t.set(Tag.EXIFPhotoLensModel, rh.full_metadata['EXIF LensModel'].values)\n", + " t.set(Tag.ExposureBiasValue, ExposureBiasValue)\n", + " t.set(Tag.ExposureTime, exposures)\n", + " t.set(Tag.FNumber, fnumber)\n", + " t.set(Tag.PhotographicSensitivity, rh.full_metadata['EXIF ISOSpeedRatings'].values)\n", + " t.set(Tag.Orientation, rh.full_metadata['Image Orientation'].values[0])\n", + " except:\n", + " \"ok\"\n", + " t.set(Tag.DNGVersion, DNGVersion.V1_4)\n", + " t.set(Tag.DNGBackwardVersion, DNGVersion.V1_2)\n", + " t.set(Tag.PreviewColorSpace, PreviewColorSpace.Adobe_RGB)\n", + "\n", + " r = RAW2DNG()\n", + "\n", + " r.options(t, path=\"\", compress=False)\n", + "\n", + " r.convert(uint_img, filename=filepath)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "07151d60", + "metadata": {}, + 
"outputs": [], + "source": [ + "for key in tqdm(pair_file_list.keys()):\n", + " image_pairs = pair_file_list[key]\n", + " for noisy, gt in image_pairs:\n", + " noisy_path = outpath / (noisy)\n", + " if not os.path.exists(str(noisy_path)+'.dng'):\n", + " print(noisy_path)\n", + " if noisy.endswith(('.cr2', '.nef', '.arw', '.orf', '.raf', '.pef', '.crw', '.dng')):\n", + " bayer, rh = get_file(f'{raw_path}/{noisy}')\n", + " to_dng(bayer, rh, str(noisy_path))\n", + "\n", + "\n", + " gt_path = outpath / (gt)\n", + " if not os.path.exists(str(gt_path)+'.dng'):\n", + " print(gt_path)\n", + " bayer, rh = get_file(f'{raw_path}/{gt}')\n", + " to_dng(bayer, rh, str(gt_path))\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9414c7a1", + "metadata": {}, + "outputs": [], + "source": [ + "#Testing data is properly copied" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8ad377b4", + "metadata": {}, + "outputs": [], + "source": [ + "rhdng = RawHandler(str(outpath / \"Bayer_MuseeL-sol-A7C-brighter_ISO100_sha1=18eaa9931d9a0f6f0511552ef6bf2fd040d82878.arw.dng\"))\n", + "rh = RawHandler('/Volumes/EasyStore/RAWNIND/Bayer_MuseeL-sol-A7C-brighter_ISO100_sha1=18eaa9931d9a0f6f0511552ef6bf2fd040d82878.arw')\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d95fa141", + "metadata": {}, + "outputs": [], + "source": [ + "imdng = rhdng.as_rgb()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "45796c08", + "metadata": {}, + "outputs": [], + "source": [ + "width, height = rh.raw.shape\n", + "\n", + "left = (width - crop_size) // 2\n", + "top = (height - crop_size) // 2\n", + "\n", + "if left % 2 != 0:\n", + " left -= 1\n", + "if top % 2 != 0:\n", + " top -= 1\n", + "\n", + "right = left + crop_size\n", + "bottom = top + crop_size\n", + "\n", + "im = rh.as_rgb(dims=(left, right, top, bottom))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "406c5201", + "metadata": {}, + 
"outputs": [], + "source": [ + "(imdng-im).mean()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7c495ba8", + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "import os\n", + "from torch.utils.data import Dataset\n", + "import imageio\n", + "from colour_demosaicing import (\n", + " ROOT_RESOURCES_EXAMPLES,\n", + " demosaicing_CFA_Bayer_bilinear,\n", + " demosaicing_CFA_Bayer_Malvar2004,\n", + " demosaicing_CFA_Bayer_Menon2007,\n", + " mosaicing_CFA_Bayer)\n", + "\n", + "from src.training.utils import inverse_gamma_tone_curve, cfa_to_sparse\n", + "import numpy as np\n", + "import torch \n", + "from src.training.align_images import apply_alignment, align_clean_to_noisy\n", + "from pathlib import Path\n", + "from RawHandler.RawHandler import RawHandler\n", + "\n", + "\n", + "\n", + "def global_affine_match(A, D, mask=None):\n", + " \"\"\"\n", + " Fit D ≈ a + b*A with least squares.\n", + " A, D : 2D arrays, same shape (linear values)\n", + " mask : optional boolean array, True=use pixel\n", + " returns: a, b, D_pred, D_resid (D - (a + b*A))\n", + " \"\"\"\n", + " A = A.ravel().astype(np.float64)\n", + " D = D.ravel().astype(np.float64)\n", + " if mask is None:\n", + " mask = np.isfinite(A) & np.isfinite(D)\n", + " else:\n", + " mask = mask.ravel() & np.isfinite(A) & np.isfinite(D)\n", + "\n", + " A0 = A[mask]\n", + " D0 = D[mask]\n", + " # design matrix [1, A]\n", + " X = np.vstack([np.ones_like(A0), A0]).T\n", + " coef, *_ = np.linalg.lstsq(X, D0, rcond=None)\n", + " a, b = coef[0], coef[1]\n", + " D_pred = (a + b * A).reshape(-1)\n", + " D_pred = D_pred.reshape(A.shape) if False else (a + b * A).reshape((-1,)) # keep flatten\n", + "\n", + " return a, b, (a + b * A)\n", + "\n", + "\n", + "def random_crop_dim(shape, crop_size, buffer, validation=False):\n", + " h, w = shape\n", + " if not validation:\n", + " top = np.random.randint(0 + buffer, h - crop_size - buffer)\n", + " left = np.random.randint(0 + buffer, w - 
crop_size - buffer)\n", + " else:\n", + " top = (h - crop_size) // 2\n", + " left = (w - crop_size) // 2\n", + "\n", + " if top % 2 != 0: top = top - 1\n", + " if left % 2 != 0: left = left - 1\n", + " bottom = top + crop_size\n", + " right = left + crop_size\n", + " return (left, right, top, bottom)\n", + "\n", + "class RawDatasetDNG(Dataset):\n", + " def __init__(self, path, csv, colorspace, crop_size=180, buffer=10, validation=False, run_align=False, dimensions=2000):\n", + " super().__init__()\n", + " self.df = pd.read_csv(csv)\n", + " self.path = path\n", + " self.crop_size = crop_size\n", + " self.buffer = buffer\n", + " self.coordinate_iso = 6400\n", + " self.validation=validation\n", + " self.run_align = run_align\n", + " self.dtype = np.float16\n", + " self.dimensions = dimensions\n", + " self.colorspace = colorspace\n", + "\n", + " def __len__(self):\n", + " return len(self.df)\n", + " \n", + " def __getitem__(self, idx):\n", + " row = self.df.iloc[idx]\n", + " # Load images\n", + " name = Path(f\"{row.bayer_path}\").name\n", + " name = str(self.path / name.replace('_bayer.jpg', '.dng'))\n", + " noisy_rh = RawHandler(name)\n", + " \n", + " name = Path(f\"{row.gt_path}\").name\n", + " name = str(self.path / name.replace('.jpg', '.dng'))\n", + " gt_rh = RawHandler(name)\n", + "\n", + "\n", + " dims = random_crop_dim(noisy_rh.raw.shape, self.crop_size, self.buffer, validation=self.validation)\n", + " bayer_data = noisy_rh.apply_colorspace_transform(dims=dims, colorspace=self.colorspace)\n", + " noisy = noisy_rh.as_rgb(dims=dims, colorspace=self.colorspace)\n", + " rggb = noisy_rh.as_rggb(dims=dims, colorspace=self.colorspace)\n", + "\n", + " expanded_dims = [dims[0]-self.buffer, dims[1]+self.buffer, dims[0]-self.buffer, dims[1]+self.buffer]\n", + " gt_expanded = gt_rh.as_rgb(dims=expanded_dims, colorspace=self.colorspace)\n", + " aligned = apply_alignment(gt_expanded.transpose(1, 2, 0), row.to_dict())[self.buffer:-self.buffer, self.buffer:-self.buffer]\n", + 
" gt_non_aligned = gt_expanded.transpose(1, 2, 0)[self.buffer:-self.buffer, self.buffer:-self.buffer]\n", + " # Convert to tensors\n", + " output = {\n", + " \"bayer\": torch.tensor(bayer_data).to(float).clip(0,1), \n", + " \"gt_non_aligned\": torch.tensor(gt_non_aligned).to(float).permute(2, 0, 1).clip(0,1), \n", + " \"aligned\": torch.tensor(aligned).to(float).permute(2, 0, 1).clip(0,1), \n", + " # \"sparse\": torch.tensor(sparse).to(float).clip(0,1),\n", + " \"noisy\": torch.tensor(noisy).to(float).clip(0,1), \n", + " \"rggb\": torch.tensor(rggb).to(float).clip(0,1),\n", + " \"conditioning\": torch.tensor([row.iso/self.coordinate_iso]).to(float), \n", + " # \"noise_est\": noise_est,\n", + " # \"rggb_gt\": rggb_gt,\n", + " }\n", + " return output" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "04525aa2", + "metadata": {}, + "outputs": [], + "source": [ + "from src.training.load_config import load_config\n", + "\n", + "run_config = load_config()\n", + "dataset_path = Path(run_config['cropped_raw_subdir'])\n", + "align_csv = dataset_path / run_config['secondary_align_csv']\n", + "\n", + "\n", + "device=run_config['device']\n", + "\n", + "batch_size = run_config['batch_size']\n", + "lr = run_config['lr_base'] * batch_size\n", + "clipping = run_config['clipping']\n", + "\n", + "num_epochs = run_config['num_epochs_pretraining']\n", + "val_split = run_config['val_split']\n", + "crop_size = run_config['crop_size']\n", + "experiment = run_config['mlflow_experiment']\n", + "mlflow_path = run_config['mlflow_path']\n", + "colorspace = run_config['colorspace']\n", + "\n", + "rggb = True" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c529a7f1", + "metadata": {}, + "outputs": [], + "source": [ + "dataset = SmallRawDatasetNumpy(dataset_path, align_csv, colorspace, crop_size=crop_size, validation=True)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d2b506c1", + "metadata": {}, + "outputs": [], + 
"source": [ + "output = dataset[0]\n", + "import matplotlib.pyplot as plt\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b5ee6bbe", + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "plt.imshow(output['noisy'].permute(1, 2, 0)**(1/2.2))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4e05c86a", + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "plt.imshow((output['noisy']-output['aligned']).permute(1, 2, 0)+0.5)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6036f9a3", + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "plt.imshow((output['noisy']-output['gt_non_aligned']).permute(1, 2, 0)+0.5)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "203687d3", + "metadata": {}, + "outputs": [], + "source": [ + "output['gt_non_aligned'].shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2f684a2f", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "819fa74d", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "OnSight", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.11" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/1_pretrain_model.ipynb b/1_pretrain_model.ipynb new file mode 100644 index 0000000..1ccfd02 --- /dev/null +++ b/1_pretrain_model.ipynb @@ -0,0 +1,205 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "f6351e77", + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "import 
matplotlib.pyplot as plt\n", + "from torch.utils.data import DataLoader, random_split\n", + "import torch.nn as nn\n", + "import torch\n", + "import copy\n", + "import mlflow\n", + "import mlflow.pytorch\n", + "from pathlib import Path" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2043dc8e", + "metadata": {}, + "outputs": [], + "source": [ + "from src.training.SmallRawDataset import SmallRawDataset\n", + "from src.training.losses.ShadowAwareLoss import ShadowAwareLoss\n", + "from src.training.VGGFeatureExtractor import VGGFeatureExtractor\n", + "from src.training.train_loop import train_one_epoch, visualize\n", + "from src.training.utils import apply_gamma_torch\n", + "from src.training.load_config import load_config\n", + "from src.Restorer.Cond_NAF import make_full_model_RGGB\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a0464c98", + "metadata": {}, + "outputs": [], + "source": [ + "run_config = load_config()\n", + "dataset_path = Path(run_config['jpeg_output_subdir'])\n", + "align_csv = dataset_path / run_config['secondary_align_csv']" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ba20b866", + "metadata": {}, + "outputs": [], + "source": [ + "device=run_config['device']\n", + "\n", + "batch_size = run_config['batch_size']\n", + "lr = run_config['lr_base'] * batch_size\n", + "clipping = run_config['clipping']\n", + "\n", + "num_epochs = run_config['num_epochs_pretraining']\n", + "val_split = run_config['val_split']\n", + "crop_size = run_config['crop_size']\n", + "experiment = run_config['mlflow_experiment']\n", + "mlflow_path = run_config['mlflow_path']\n", + "mlflow.set_tracking_uri(f\"file://{mlflow_path}\")\n", + "mlflow.set_experiment(experiment)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a36eaef2", + "metadata": {}, + "outputs": [], + "source": [ + "model_params = run_config['model_params']\n", + "rggb = model_params['rggb']" + ] + }, + { + 
"cell_type": "code", + "execution_count": null, + "id": "15f16fa7", + "metadata": {}, + "outputs": [], + "source": [ + "dataset = SmallRawDataset(dataset_path, align_csv, crop_size=crop_size)\n", + "\n", + "# Split dataset into train and val\n", + "val_size = int(len(dataset) * val_split)\n", + "train_size = len(dataset) - val_size\n", + "torch.manual_seed(42) # For reproducibility\n", + "train_dataset, val_dataset = random_split(dataset, [train_size, val_size])\n", + "# Set the validation dataset to use the same crops\n", + "val_dataset = copy.deepcopy(val_dataset)\n", + "val_dataset.dataset.validation = True\n", + "\n", + "train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False, num_workers=0)\n", + "val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "086fbb3a", + "metadata": {}, + "outputs": [], + "source": [ + "model = make_full_model_RGGB(model_params, model_name=None)\n", + "model = model.to(device)\n", + "\n", + "params = {**run_config, **model_params}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6af0f3a2", + "metadata": {}, + "outputs": [], + "source": [ + "optimizer = torch.optim.Adam(model.parameters(), lr=lr)\n", + "\n", + "vfe = VGGFeatureExtractor(config=((1, 64), (1, 128), (1, 256), (1, 512), (1, 512),), \n", + " feature_layers=[14], \n", + " activation=nn.ReLU\n", + " )\n", + "vfe = vfe.to(device)\n", + "\n", + "loss_fn = ShadowAwareLoss(\n", + " alpha=run_config['alpha'],\n", + " beta=run_config['beta'],\n", + " l1_weight=run_config['l1_weight'],\n", + " ssim_weight=run_config['ssim_weight'],\n", + " tv_weight=run_config['tv_weight'],\n", + " vgg_loss_weight=run_config['vgg_loss_weight'],\n", + " apply_gamma_fn=apply_gamma_torch,\n", + " vgg_feature_extractor=vfe,\n", + " device=device,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8bc86a06", + "metadata": 
{}, + "outputs": [], + "source": [ + "with mlflow.start_run(run_name=run_config['run_name']) as run:\n", + "\n", + " mlflow.log_params(params)\n", + " for epoch in range(num_epochs):\n", + " train_one_epoch(epoch, model, optimizer, train_loader, device, loss_fn, clipping, \n", + " log_interval = 10, sleep=0.0, rggb=rggb, max_batches=0)\n", + " \n", + " mlflow.pytorch.log_model(\n", + " pytorch_model=model,\n", + " name=run_config['run_path'],\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2f95b42f", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5a775d31", + "metadata": {}, + "outputs": [], + "source": [ + "run.info.run_id" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "OnSight", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.11" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/1_train_model_raw.ipynb b/1_train_model_raw.ipynb new file mode 100644 index 0000000..f68ce2e --- /dev/null +++ b/1_train_model_raw.ipynb @@ -0,0 +1,216 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "f6351e77", + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "import matplotlib.pyplot as plt\n", + "from torch.utils.data import DataLoader, random_split\n", + "import torch.nn as nn\n", + "import torch\n", + "import copy\n", + "import mlflow\n", + "import mlflow.pytorch\n", + "from pathlib import Path" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2043dc8e", + "metadata": {}, + "outputs": [], + "source": [ + "from src.training.RawDatasetDNG import RawDatasetDNG\n", + "from src.training.losses.ShadowAwareLoss import 
ShadowAwareLoss\n", + "from src.training.VGGFeatureExtractor import VGGFeatureExtractor\n", + "from src.training.train_loop import train_one_epoch, visualize\n", + "from src.training.utils import apply_gamma_torch\n", + "from src.training.load_config import load_config\n", + "from src.Restorer.Cond_NAF import make_full_model_RGGB\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a0464c98", + "metadata": {}, + "outputs": [], + "source": [ + "run_config = load_config()\n", + "dataset_path = Path(run_config['cropped_raw_subdir'])\n", + "align_csv = dataset_path / run_config['secondary_align_csv']" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7322456f", + "metadata": {}, + "outputs": [], + "source": [ + "dataset_path" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ba20b866", + "metadata": {}, + "outputs": [], + "source": [ + "device=run_config['device']\n", + "\n", + "batch_size = run_config['batch_size']\n", + "lr = run_config['lr_base'] * batch_size\n", + "clipping = run_config['clipping']\n", + "\n", + "num_epochs = run_config['num_epochs_pretraining']\n", + "cosine_annealing = run_config['cosine_annealing']\n", + "\n", + "val_split = run_config['val_split']\n", + "crop_size = run_config['crop_size']\n", + "experiment = run_config['mlflow_experiment']\n", + "mlflow_path = run_config['mlflow_path']\n", + "colorspace = run_config['colorspace']\n", + "iso_range = run_config['iso_range']\n", + "\n", + "rggb = True\n", + "mlflow.set_tracking_uri(f\"file://{mlflow_path}\")\n", + "mlflow.set_experiment(experiment)\n", + "\n", + "params = {**run_config}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e9a26124", + "metadata": {}, + "outputs": [], + "source": [ + "model_params = run_config['model_params']\n", + "rggb = model_params['rggb']\n", + "\n", + "model = make_full_model_RGGB(model_params, model_name=None)\n", + "model = model.to(device)\n", + "\n", + "params = 
{**run_config, **model_params}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "15f16fa7", + "metadata": {}, + "outputs": [], + "source": [ + "dataset = RawDatasetDNG(dataset_path, align_csv, colorspace, crop_size=crop_size)\n", + "dataset.df = dataset.df[~dataset.df.bayer_path.str.contains('crw')]\n", + "dataset.df = dataset.df[~dataset.df.bayer_path.str.contains('dng_bayer')]\n", + "dataset.df = dataset.df[(dataset.df.iso >= iso_range[0]) & (dataset.df.iso <= iso_range[1])]\n", + "print(len(dataset.df ))\n", + "# Split dataset into train and val\n", + "val_size = int(len(dataset) * val_split)\n", + "train_size = len(dataset) - val_size\n", + "torch.manual_seed(42) # For reproducibility\n", + "train_dataset, val_dataset = random_split(dataset, [train_size, val_size])\n", + "# Set the validation dataset to use the same crops\n", + "val_dataset = copy.deepcopy(val_dataset)\n", + "val_dataset.dataset.validation = True\n", + "\n", + "train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)\n", + "val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True, num_workers=0)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6af0f3a2", + "metadata": {}, + "outputs": [], + "source": [ + "optimizer = torch.optim.Adam(model.parameters(), lr=lr)\n", + "if cosine_annealing:\n", + " sched = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, num_epochs,eta_min=lr*1e-6)\n", + "else:\n", + " sched = None\n", + " \n", + "vfe = VGGFeatureExtractor(config=((1, 64), (1, 128), (1, 256), (1, 512), (1, 512),), \n", + " feature_layers=[14], \n", + " activation=nn.ReLU\n", + " )\n", + "vfe = vfe.to(device)\n", + "\n", + "loss_fn = ShadowAwareLoss(\n", + " alpha=run_config['alpha'],\n", + " beta=run_config['beta'],\n", + " l1_weight=run_config['l1_weight'],\n", + " ssim_weight=run_config['ssim_weight'],\n", + " tv_weight=run_config['tv_weight'],\n", + " 
vgg_loss_weight=run_config['vgg_loss_weight'],\n", + " percept_loss_weight=run_config['percept_loss_weight'],\n", + " apply_gamma_fn=apply_gamma_torch,\n", + " vgg_feature_extractor=vfe,\n", + " device=device,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8bc86a06", + "metadata": {}, + "outputs": [], + "source": [ + "with mlflow.start_run(run_name=run_config['run_name']) as run:\n", + " mlflow.log_params(params)\n", + " for epoch in range(num_epochs):\n", + " train_one_epoch(epoch, model, optimizer, train_loader, device, loss_fn, clipping, \n", + " log_interval = 10, sleep=0.0, rggb=rggb, max_batches=0)\n", + " if cosine_annealing:\n", + " sched.step()\n", + " \n", + " mlflow.pytorch.log_model(\n", + " pytorch_model=model,\n", + " name=run_config['run_path'],\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5a775d31", + "metadata": {}, + "outputs": [], + "source": [ + "run.info.run_id" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "OnSight", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.11" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/1_train_model_raw_demosaicing.ipynb b/1_train_model_raw_demosaicing.ipynb new file mode 100644 index 0000000..9c95d38 --- /dev/null +++ b/1_train_model_raw_demosaicing.ipynb @@ -0,0 +1,302 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "f6351e77", + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "import matplotlib.pyplot as plt\n", + "from torch.utils.data import DataLoader, random_split\n", + "import torch.nn as nn\n", + "import torch\n", + "import copy\n", + "import mlflow\n", + "import mlflow.pytorch\n", + "from 
pathlib import Path" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2043dc8e", + "metadata": {}, + "outputs": [], + "source": [ + "from src.training.DemosaicingDataset import DemosaicingDataset\n", + "from src.training.losses.ShadowAwareLoss import ShadowAwareLoss\n", + "from src.training.VGGFeatureExtractor import VGGFeatureExtractor\n", + "from src.training.train_loop import train_one_epoch, visualize\n", + "from src.training.utils import apply_gamma_torch\n", + "from src.training.load_config import load_config\n", + "from src.Restorer.Cond_NAF_demosaic import make_full_model_RGGB_Demosaicing\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a0464c98", + "metadata": {}, + "outputs": [], + "source": [ + "run_config = load_config('config_demosaicing.yaml')\n", + "dataset_path = Path(run_config['cropped_raw_subdir'])\n", + "align_csv = dataset_path / run_config['secondary_align_csv']" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ba20b866", + "metadata": {}, + "outputs": [], + "source": [ + "device=run_config['device']\n", + "\n", + "batch_size = run_config['batch_size']\n", + "lr = run_config['lr_base'] * batch_size\n", + "clipping = run_config['clipping']\n", + "\n", + "num_epochs = run_config['num_epochs_pretraining']\n", + "cosine_annealing = run_config['cosine_annealing']\n", + "\n", + "val_split = run_config['val_split']\n", + "crop_size = run_config['crop_size']\n", + "experiment = run_config['mlflow_experiment']\n", + "mlflow_path = run_config['mlflow_path']\n", + "colorspace = run_config['colorspace']\n", + "iso_range = run_config['iso_range']\n", + "\n", + "rggb = True\n", + "mlflow.set_tracking_uri(f\"file://{mlflow_path}\")\n", + "mlflow.set_experiment(experiment)\n", + "\n", + "params = {**run_config}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "48e645a0", + "metadata": {}, + "outputs": [], + "source": [ + "mlflow_path" + ] + }, + { + "cell_type": 
"code", + "execution_count": null, + "id": "e9a26124", + "metadata": {}, + "outputs": [], + "source": [ + "model_params = run_config['model_params']\n", + "rggb = model_params['rggb']\n", + "\n", + "model = make_full_model_RGGB_Demosaicing(model_params, model_name=None)\n", + "model = model.to(device)\n", + "\n", + "params = {**run_config, **model_params}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "15f16fa7", + "metadata": {}, + "outputs": [], + "source": [ + "dataset = DemosaicingDataset(dataset_path, align_csv, colorspace, output_crop_size=crop_size, downsample_factor=4)\n", + "dataset.df = dataset.df[~dataset.df.bayer_path.str.contains('crw')]\n", + "dataset.df = dataset.df[~dataset.df.bayer_path.str.contains('dng_bayer')]\n", + "dataset.df = dataset.df[(dataset.df.iso >= iso_range[0]) & (dataset.df.iso <= iso_range[1])]\n", + "print(len(dataset.df ))\n", + "# Split dataset into train and val\n", + "val_size = int(len(dataset) * val_split)\n", + "train_size = len(dataset) - val_size\n", + "torch.manual_seed(42) # For reproducibility\n", + "train_dataset, val_dataset = random_split(dataset, [train_size, val_size])\n", + "# Set the validation dataset to use the same crops\n", + "val_dataset = copy.deepcopy(val_dataset)\n", + "val_dataset.dataset.validation = True\n", + "\n", + "train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)\n", + "val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True, num_workers=0)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6af0f3a2", + "metadata": {}, + "outputs": [], + "source": [ + "optimizer = torch.optim.Adam(model.parameters(), lr=lr)\n", + "if cosine_annealing:\n", + " sched = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, num_epochs,eta_min=lr*1e-6)\n", + "else:\n", + " sched = None\n", + " \n", + "vfe = VGGFeatureExtractor(config=((1, 64), (1, 128), (1, 256), (1, 512), (1, 512),), \n", + " 
feature_layers=[14], \n", + " activation=nn.ReLU\n", + " )\n", + "vfe = vfe.to(device)\n", + "\n", + "loss_fn = ShadowAwareLoss(\n", + " alpha=run_config['alpha'],\n", + " beta=run_config['beta'],\n", + " l1_weight=run_config['l1_weight'],\n", + " ssim_weight=run_config['ssim_weight'],\n", + " tv_weight=run_config['tv_weight'],\n", + " vgg_loss_weight=run_config['vgg_loss_weight'],\n", + " percept_loss_weight=run_config['percept_loss_weight'],\n", + " apply_gamma_fn=apply_gamma_torch,\n", + " vgg_feature_extractor=vfe,\n", + " device=device,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "58244dcc", + "metadata": {}, + "outputs": [], + "source": [ + "loss_fn = nn.L1Loss()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a4616a32", + "metadata": {}, + "outputs": [], + "source": [ + "from time import perf_counter\n", + "import time\n", + "from tqdm import tqdm\n", + "import torch\n", + "import torch.nn as nn\n", + "from src.training.utils import apply_gamma_torch\n", + "import mlflow\n", + "\n", + "def make_conditioning(conditioning, device):\n", + " B = conditioning.shape[0]\n", + " conditioning_extended = torch.zeros(B, 1).to(device)\n", + " conditioning_extended[:, 0] = conditioning[:, 0]\n", + " return conditioning_extended\n", + "\n", + "\n", + "def train_one_epoch(epoch, _model, _optimizer, _loader, _device, _loss_func, _clipping, \n", + " log_interval = 10, sleep=0.0, rggb=False,\n", + " max_batches=0):\n", + " _model.train()\n", + " total_loss, n_images, total_l1_loss = 0.0, 0, 0.0\n", + " start = perf_counter()\n", + " pbar = tqdm(enumerate(_loader), total=len(_loader), desc=f\"Train Epoch {epoch}\")\n", + "\n", + " for batch_idx, (output) in pbar:\n", + " conditioning = output['conditioning'].float().to(_device)\n", + " gt = output['ground_truth'].float().to(_device)\n", + " input = output['cfa_sparse'].float().to(_device)\n", + " if rggb:\n", + " input = output['cfa_rggb'].float().to(_device)\n", + 
" conditioning = make_conditioning(conditioning, _device)\n", + "\n", + " _optimizer.zero_grad(set_to_none=True)\n", + " pred = _model(input, conditioning) \n", + "\n", + " loss = _loss_func(pred, gt)\n", + " _optimizer.zero_grad(set_to_none=True)\n", + " loss.backward()\n", + " torch.nn.utils.clip_grad_norm_(_model.parameters(), _clipping)\n", + " _optimizer.step()\n", + "\n", + " total_loss += float(loss.detach().cpu())\n", + " n_images += gt.shape[0]\n", + "\n", + " # Testing final image quality\n", + " final_image_loss = float(nn.functional.l1_loss(pred, gt).detach().cpu())\n", + " total_l1_loss += final_image_loss\n", + " del loss, pred, final_image_loss\n", + " torch.mps.empty_cache() \n", + "\n", + " if (batch_idx + 1) % log_interval == 0:\n", + " pbar.set_postfix({\"loss\": f\"{total_loss/n_images:.4f}\"})\n", + "\n", + " if (max_batches > 0) and (batch_idx+1 > max_batches): break\n", + " time.sleep(sleep)\n", + "\n", + " train_time = perf_counter()-start\n", + " print(f\"[Epoch {epoch}] \"\n", + " f\"Train loss: {total_loss/n_images:.6f} \"\n", + " f\"L1 loss: {total_l1_loss/n_images:.6f} \"\n", + " f\"Time: {train_time:.1f}s \"\n", + " f\"Images: {n_images}\")\n", + " mlflow.log_metric(\"train_loss\", total_loss/n_images, step=epoch)\n", + " mlflow.log_metric(\"l1_loss\", total_l1_loss/n_images, step=epoch)\n", + " mlflow.log_metric(\"epoch_duration_s\", train_time, step=epoch)\n", + " mlflow.log_metric(\"learning_rate\", _optimizer.param_groups[0]['lr'], step=epoch)\n", + "\n", + " return total_loss / max(1, n_images), perf_counter()-start" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8bc86a06", + "metadata": {}, + "outputs": [], + "source": [ + "with mlflow.start_run(run_name=run_config['run_name']) as run:\n", + " mlflow.log_params(params)\n", + " for epoch in range(num_epochs):\n", + " train_one_epoch(epoch, model, optimizer, train_loader, device, loss_fn, clipping, \n", + " log_interval = 10, sleep=0.0, rggb=rggb, 
max_batches=0)\n", + " if cosine_annealing:\n", + " sched.step()\n", + " \n", + " mlflow.pytorch.log_model(\n", + " pytorch_model=model,\n", + " name=run_config['run_path'],\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5a775d31", + "metadata": {}, + "outputs": [], + "source": [ + "run.info.run_id" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "OnSight", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.11" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/1_validate_model.ipynb b/1_validate_model.ipynb new file mode 100644 index 0000000..b1400c0 --- /dev/null +++ b/1_validate_model.ipynb @@ -0,0 +1,567 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "f6351e77", + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "import matplotlib.pyplot as plt\n", + "from torch.utils.data import DataLoader, random_split\n", + "import torch.nn as nn\n", + "import torch\n", + "import copy\n", + "import mlflow\n", + "from pathlib import Path" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2043dc8e", + "metadata": {}, + "outputs": [], + "source": [ + "from src.training.SmallRawDataset import SmallRawDataset\n", + "\n", + "from src.training.losses.ShadowAwareLoss import ShadowAwareLoss\n", + "from src.training.VGGFeatureExtractor import VGGFeatureExtractor\n", + "from src.training.utils import apply_gamma_torch\n", + "from src.training.train_loop import visualize\n", + "from src.training.load_config import load_config\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b14af214", + "metadata": {}, + "outputs": [], + "source": [ + "run_config = load_config()\n", + 
"dataset_path = Path(run_config['jpeg_output_subdir'])\n", + "align_csv = dataset_path / run_config['secondary_align_csv']\n", + "device=run_config['device']\n", + "\n", + "batch_size = run_config['batch_size']\n", + "lr = run_config['lr_base'] * batch_size\n", + "clipping = run_config['clipping']\n", + "\n", + "num_epochs = run_config['num_epochs_pretraining']\n", + "val_split = run_config['val_split']\n", + "crop_size = run_config['crop_size']\n", + "experiment = run_config['mlflow_experiment']\n", + "model_params = run_config['model_params']\n", + "rggb = model_params['rggb']\n", + "mlflow_path = run_config['mlflow_path']\n", + "mlflow.set_tracking_uri(f\"file://{mlflow_path}\")\n", + "mlflow.set_experiment(experiment)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ba20b866", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "RUN_ID = \"b0664f324e9444d3b3a5277d513d3642\" \n", + "ARTIFACT_PATH = run_config['run_path']\n", + "\n", + "model_uri = f\"runs:/{RUN_ID}/{ARTIFACT_PATH}\"\n", + "\n", + "try:\n", + " model = mlflow.pytorch.load_model(model_uri)\n", + " model.eval()\n", + " print(f\"Model successfully loaded from MLflow URI: {model_uri}\")\n", + " \n", + "\n", + "except Exception as e:\n", + " print(f\"Error loading model from MLflow: {e}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "15f16fa7", + "metadata": {}, + "outputs": [], + "source": [ + "dataset = SmallRawDataset(dataset_path, align_csv, crop_size=crop_size)\n", + "\n", + "# Split dataset into train and val\n", + "val_size = int(len(dataset) * val_split)\n", + "train_size = len(dataset) - val_size\n", + "torch.manual_seed(42) # For reproducibility\n", + "train_dataset, val_dataset = random_split(dataset, [train_size, val_size])\n", + "# Set the validation dataset to use the same crops\n", + "val_dataset = copy.deepcopy(val_dataset)\n", + "val_dataset.dataset.validation = True\n", + "\n", + "train_loader = DataLoader(train_dataset, 
batch_size=batch_size, shuffle=False, num_workers=0)\n", + "val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6af0f3a2", + "metadata": {}, + "outputs": [], + "source": [ + "optimizer = torch.optim.Adam(model.parameters(), lr=lr)\n", + "\n", + "vfe = VGGFeatureExtractor(config=((1, 64), (1, 128), (1, 256), (1, 512), (1, 512),), \n", + " feature_layers=[14], \n", + " activation=nn.ReLU\n", + " )\n", + "vfe = vfe.to(device)\n", + "\n", + "loss_fn = ShadowAwareLoss(\n", + " alpha=run_config['alpha'],\n", + " beta=run_config['beta'],\n", + " l1_weight=run_config['l1_weight'],\n", + " ssim_weight=run_config['ssim_weight'],\n", + " tv_weight=run_config['tv_weight'],\n", + " vgg_loss_weight=run_config['vgg_loss_weight'],\n", + " apply_gamma_fn=apply_gamma_torch,\n", + " vgg_feature_extractor=vfe,\n", + " device=device,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "51ad2666", + "metadata": {}, + "outputs": [], + "source": [ + "# with mlflow.start_run(run_id=existing_run_id) as run:\n", + "# print(f\"Re-opened run: {run.info.run_name}\")\n", + " \n", + "# # Log your new validation metric\n", + "# mlflow.log_metric(\"final_validation_accuracy\", new_validation_metric)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ae01182c", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "subset_indices = np.array(val_dataset.indices) # indices in the original dataset\n", + "mask = val_dataset.dataset.df.iso.values[subset_indices] == 65535\n", + "matching_indices_in_subset = np.nonzero(mask)[0]\n", + "matching_indices_in_subset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2234a52d", + "metadata": {}, + "outputs": [], + "source": [ + "visualize(matching_indices_in_subset, model, val_dataset, device, loss_fn, rggb=rggb)" + ] + }, + { + "cell_type": "code", + 
"execution_count": null, + "id": "641d864b", + "metadata": {}, + "outputs": [], + "source": [ + "Train loss: 0.056005 Final image val loss: 0.009073 Time: 45.7s Images: 25\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "256f0de3", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import torch\n", + "import mlflow\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "from time import perf_counter\n", + "\n", + "\n", + "def make_conditioning(conditioning, device):\n", + " B = conditioning.shape[0]\n", + " conditioning_extended = torch.zeros(B, 1).to(device)\n", + " conditioning_extended[:, 0] = conditioning[:, 0]\n", + " return conditioning_extended\n", + "\n", + "\n", + "def validate_model(\n", + " _model, \n", + " val_dataset, \n", + " _device, \n", + " _loss_func, epoch,\n", + " iso_values=(100, 400, 1600, 65535), \n", + " n_examples=3, \n", + " rggb=False,\n", + " artifact_dir=\"val_examples\"\n", + "):\n", + " \"\"\"\n", + " Run validation over different ISO values and log metrics + sample images to MLflow.\n", + "\n", + " Args:\n", + " _model: torch model\n", + " val_dataset: dataset with `df.iso` values accessible\n", + " _device: torch device\n", + " _loss_func: loss function\n", + " iso_values: list/tuple of ISO values to evaluate\n", + " n_examples: number of images to log per ISO\n", + " rggb: whether to use rggb input\n", + " artifact_dir: subdirectory to store example images for MLflow\n", + " \"\"\"\n", + "\n", + " _model.eval()\n", + " os.makedirs(artifact_dir, exist_ok=True)\n", + "\n", + " all_metrics = {}\n", + "\n", + " for iso in iso_values:\n", + " # Select subset indices with this ISO\n", + " subset_indices = np.array(val_dataset.indices)\n", + " mask = val_dataset.dataset.df.iso.values[subset_indices] == iso\n", + " matching_indices_in_subset = np.nonzero(mask)[0]\n", + "\n", + " if len(matching_indices_in_subset) == 0:\n", + " print(f\"No validation samples for ISO {iso}\")\n", + 
" continue\n", + "\n", + "\n", + " idxs = matching_indices_in_subset[:n_examples]\n", + "\n", + " total_loss, final_img_loss, duration = visualize(\n", + " idxs, _model, val_dataset, _device, _loss_func, rggb=rggb\n", + " )\n", + "\n", + "\n", + " # Save and log a few example denoised images\n", + " for i, idx in enumerate(idxs):\n", + " row = val_dataset[idx]\n", + " noisy = row['noisy'].unsqueeze(0).float().to(_device)\n", + " conditioning = make_conditioning(row['conditioning'].float().unsqueeze(0).to(_device), _device)\n", + " gt = row['aligned'].unsqueeze(0).float().to(_device)\n", + " input = row['rggb' if rggb else 'sparse'].unsqueeze(0).float().to(_device)\n", + "\n", + " with torch.no_grad():\n", + " with torch.autocast(device_type=\"mps\", dtype=torch.bfloat16):\n", + " pred = _model(input, conditioning, noisy)\n", + "\n", + " pred_img = apply_gamma_torch(pred[0].cpu().permute(1, 2, 0))\n", + " noisy_img = apply_gamma_torch(noisy[0].cpu().permute(1, 2, 0))\n", + " gt_img = apply_gamma_torch(gt[0].cpu().permute(1, 2, 0))\n", + "\n", + " fig, axs = plt.subplots(1, 3, figsize=(15, 5))\n", + " axs[0].imshow(noisy_img)\n", + " axs[0].set_title(f\"Noisy (ISO {iso})\")\n", + " axs[1].imshow(pred_img)\n", + " axs[1].set_title(\"Denoised\")\n", + " axs[2].imshow(gt_img)\n", + " axs[2].set_title(\"Ground Truth\")\n", + " for ax in axs: ax.axis('off')\n", + "\n", + " img_path = os.path.join(artifact_dir, f\"iso{iso}_example{i}.png\")\n", + " plt.savefig(img_path, bbox_inches=\"tight\")\n", + " plt.close(fig)\n", + "\n", + " mlflow.log_artifact(img_path, artifact_path=f\"{artifact_dir}/ISO_{iso}\")\n", + "\n", + " # Log metrics per ISO\n", + " mlflow.log_metrics({\n", + " f\"val_loss_ISO_{iso}\": total_loss,\n", + " f\"final_img_loss_ISO_{iso}\": final_img_loss,\n", + " f\"val_duration_ISO_{iso}\": duration\n", + " }, step=epoch)\n", + "\n", + " # Also log a summary table\n", + " print(\"\\nValidation Summary:\")\n", + " for iso, metrics in all_metrics.items():\n", + 
" print(f\"ISO {iso}: val_loss={metrics['val_loss']:.6f}, \"\n", + " f\"final_img_loss={metrics['final_image_loss']:.6f}, \"\n", + " f\"time={metrics['duration']:.1f}s\")\n", + "\n", + " return all_metrics\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "15ff7052", + "metadata": {}, + "outputs": [], + "source": [ + "with mlflow.start_run(run_name=run_config['run_name']) as run:\n", + " metrics = validate_model(model, val_dataset, device, loss_fn, rggb=True)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "db26581d", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "\n", + "class SwiGLU(nn.Module):\n", + "\n", + " def __init__(self, input_dim, hidden_dim, dropout=0.1):\n", + " super().__init__()\n", + " self.w1 = nn.Conv2d(input_dim, hidden_dim, 1, 1, 0, 1)\n", + " self.w2 = nn.Conv2d(input_dim, hidden_dim, 1, 1, 0, 1)\n", + " self.w3 = nn.Conv2d(hidden_dim, input_dim, 1, 1, 0, 1)\n", + " \n", + " def forward(self, x):\n", + " gate = F.silu(self.w1(x)) \n", + " value = self.w2(x)\n", + " x = gate * value \n", + " \n", + " x = self.w3(x)\n", + " return x" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cc387d19", + "metadata": {}, + "outputs": [], + "source": [ + "class ConditionedChannelAttention(nn.Module):\n", + " def __init__(self, dims, cat_dims):\n", + " super().__init__()\n", + " in_dim = dims + cat_dims\n", + " self.mlp = nn.Sequential(nn.Linear(in_dim, dims))\n", + " self.pool = nn.AdaptiveAvgPool2d(1)\n", + "\n", + " def forward(self, x, conditioning):\n", + " pool = self.pool(x)\n", + " conditioning = conditioning.unsqueeze(-1).unsqueeze(-1)\n", + " cat_channels = torch.cat([pool, conditioning], dim=1)\n", + " cat_channels = cat_channels.permute(0, 2, 3, 1)\n", + " ca = self.mlp(cat_channels).permute(0, 3, 1, 2)\n", + "\n", + " return ca" + ] + }, + { + "cell_type": "code", + 
"execution_count": null, + "id": "86fdd311", + "metadata": {}, + "outputs": [], + "source": [ + "import torch.nn.functional as F\n", + "\n", + "class SwiGLU(nn.Module):\n", + "\n", + " def __init__(self, input_dim, hidden_dim, dropout=0.1):\n", + " super().__init__()\n", + " self.w1 = nn.Conv2d(input_dim, hidden_dim, 1, 1, 0, 1)\n", + " self.w2 = nn.Conv2d(input_dim, hidden_dim, 1, 1, 0, 1)\n", + " self.w3 = nn.Conv2d(hidden_dim, input_dim, 1, 1, 0, 1)\n", + " \n", + " def forward(self, x):\n", + " gate = F.silu(self.w1(x)) \n", + " value = self.w2(x)\n", + " x = gate * value \n", + " \n", + " x = self.w3(x)\n", + " return x\n", + " \n", + "class AttnBlock(nn.Module):\n", + " def __init__(self, c, FFN_Expand=2, drop_out_rate=0.0, cond_chans=0):\n", + " super().__init__()\n", + " \n", + " self.dw = nn.Conv2d(\n", + " in_channels=c,\n", + " out_channels=c,\n", + " kernel_size=3,\n", + " padding=1,\n", + " stride=1,\n", + " groups=c,\n", + " bias=True,\n", + " )\n", + "\n", + " self.sca = ConditionedChannelAttention(c, cond_chans)\n", + "\n", + " self.norm = nn.GroupNorm(1, c)\n", + " \n", + " self.swiglu = SwiGLU(c, int(c * FFN_Expand))\n", + " self.alpha = nn.Parameter(torch.zeros(1, c, 1, 1))\n", + " self.beta = nn.Parameter(torch.zeros(1, c, 1, 1))\n", + "\n", + "\n", + " def forward(self, input):\n", + " inp = input[0]\n", + " cond = input[1]\n", + "\n", + " x = self.dw(inp)\n", + " x = self.sca(x, cond) * x\n", + " y = self.norm(inp + self.alpha * x )\n", + "\n", + "\n", + " x = self.swiglu(y)\n", + " x = y + self.beta * x\n", + " return (x, cond)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "990a5993", + "metadata": {}, + "outputs": [], + "source": [ + "block = AttnBlock(64, 2, cond_chans=1)\n", + "block((torch.rand(1, 64, 32, 32), torch.rand(1, 1)))[0].shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5b460abf", + "metadata": {}, + "outputs": [], + "source": [ + "class NAFBlock0(nn.Module):\n", + " def 
__init__(self, c, DW_Expand=2, FFN_Expand=2, drop_out_rate=0.0, cond_chans=0):\n", + " super().__init__()\n", + " dw_channel = c * DW_Expand\n", + " self.conv1 = nn.Conv2d(\n", + " in_channels=c,\n", + " out_channels=dw_channel,\n", + " kernel_size=1,\n", + " padding=0,\n", + " stride=1,\n", + " groups=1,\n", + " bias=True,\n", + " )\n", + " self.conv2 = nn.Conv2d(\n", + " in_channels=dw_channel,\n", + " out_channels=dw_channel,\n", + " kernel_size=3,\n", + " padding=1,\n", + " stride=1,\n", + " groups=dw_channel,\n", + " bias=True,\n", + " )\n", + " self.conv3 = nn.Conv2d(\n", + " in_channels=dw_channel // 2,\n", + " out_channels=c,\n", + " kernel_size=1,\n", + " padding=0,\n", + " stride=1,\n", + " groups=1,\n", + " bias=True,\n", + " )\n", + "\n", + " # Simplified Channel Attention\n", + " self.sca = ConditionedChannelAttention(dw_channel // 2, cond_chans)\n", + "\n", + " # SimpleGate\n", + " self.sg = SimpleGate()\n", + "\n", + " ffn_channel = FFN_Expand * c\n", + " self.conv4 = nn.Conv2d(\n", + " in_channels=c,\n", + " out_channels=ffn_channel,\n", + " kernel_size=1,\n", + " padding=0,\n", + " stride=1,\n", + " groups=1,\n", + " bias=True,\n", + " )\n", + " self.conv5 = nn.Conv2d(\n", + " in_channels=ffn_channel // 2,\n", + " out_channels=c,\n", + " kernel_size=1,\n", + " padding=0,\n", + " stride=1,\n", + " groups=1,\n", + " bias=True,\n", + " )\n", + "\n", + " # self.grn = GRN(ffn_channel // 2)\n", + "\n", + " self.norm1 = LayerNorm2d(c)\n", + " self.norm2 = LayerNorm2d(c)\n", + "\n", + " self.dropout1 = (\n", + " nn.Dropout(drop_out_rate) if drop_out_rate > 0.0 else nn.Identity()\n", + " )\n", + " self.dropout2 = (\n", + " nn.Dropout(drop_out_rate) if drop_out_rate > 0.0 else nn.Identity()\n", + " )\n", + "\n", + " self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)\n", + " self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)\n", + "\n", + " def forward(self, input):\n", + " inp = input[0]\n", + " cond = input[1]\n", 
+ "\n", + " x = inp\n", + "\n", + " x = self.norm1(x)\n", + " x = self.conv1(x)\n", + " x = self.conv2(x)\n", + " x = self.sg(x)\n", + " x = x * self.sca(x, cond)\n", + " x = self.conv3(x)\n", + "\n", + " x = self.dropout1(x)\n", + "\n", + " y = inp + x * self.beta\n", + "\n", + " # Channel Mixing\n", + " x = self.conv4(self.norm2(y))\n", + " x = self.sg(x)\n", + " # x = self.grn(x)\n", + " x = self.conv5(x)\n", + "\n", + " x = self.dropout2(x)\n", + "\n", + " return (y + x * self.gamma, cond)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "OnSight", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.11" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/1_validated_raw_demosaicing.ipynb b/1_validated_raw_demosaicing.ipynb new file mode 100644 index 0000000..3698218 --- /dev/null +++ b/1_validated_raw_demosaicing.ipynb @@ -0,0 +1,377 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "f6351e77", + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "import matplotlib.pyplot as plt\n", + "from torch.utils.data import DataLoader, random_split\n", + "import torch.nn as nn\n", + "import torch\n", + "import copy\n", + "import mlflow\n", + "import mlflow.pytorch\n", + "from pathlib import Path" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2043dc8e", + "metadata": {}, + "outputs": [], + "source": [ + "from src.training.DemosaicingDataset import DemosaicingDataset\n", + "from src.training.losses.ShadowAwareLoss import ShadowAwareLoss\n", + "from src.training.VGGFeatureExtractor import VGGFeatureExtractor\n", + "from src.training.train_loop import train_one_epoch, visualize\n", + "from src.training.utils import 
apply_gamma_torch\n", + "from src.training.load_config import load_config\n", + "from src.Restorer.Cond_NAF import make_full_model_RGGB_Demosaicing\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a0464c98", + "metadata": {}, + "outputs": [], + "source": [ + "run_config = load_config('config_demosaicing.yaml')\n", + "dataset_path = Path(run_config['cropped_raw_subdir'])\n", + "align_csv = dataset_path / run_config['secondary_align_csv']" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ba20b866", + "metadata": {}, + "outputs": [], + "source": [ + "device=run_config['device']\n", + "\n", + "batch_size = run_config['batch_size']\n", + "lr = run_config['lr_base'] * batch_size\n", + "clipping = run_config['clipping']\n", + "\n", + "num_epochs = run_config['num_epochs_pretraining']\n", + "cosine_annealing = run_config['cosine_annealing']\n", + "\n", + "val_split = run_config['val_split']\n", + "crop_size = run_config['crop_size']\n", + "experiment = run_config['mlflow_experiment']\n", + "mlflow_path = run_config['mlflow_path']\n", + "colorspace = run_config['colorspace']\n", + "iso_range = run_config['iso_range']\n", + "\n", + "rggb = True\n", + "mlflow.set_tracking_uri(f\"file://{mlflow_path}\")\n", + "mlflow.set_experiment(experiment)\n", + "\n", + "params = {**run_config}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e9a26124", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "RUN_ID = \"01a8dc61c70c47f99e3b846f21c11fb7\" \n", + "ARTIFACT_PATH = run_config['run_path']\n", + "\n", + "model_uri = f\"runs:/{RUN_ID}/{ARTIFACT_PATH}\"\n", + "\n", + "try:\n", + " model = mlflow.pytorch.load_model(model_uri)\n", + " model.eval()\n", + " print(f\"Model successfully loaded from MLflow URI: {model_uri}\")\n", + " \n", + "\n", + "except Exception as e:\n", + " print(f\"Error loading model from MLflow: {e}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "deaae86d", + 
"metadata": {}, + "outputs": [], + "source": [ + "from src.Restorer.debayering import Debayer5x5\n", + "debayer = Debayer5x5()\n", + "debayer = debayer.to(device)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "15f16fa7", + "metadata": {}, + "outputs": [], + "source": [ + "dataset = DemosaicingDataset(dataset_path, align_csv, colorspace, output_crop_size=crop_size, downsample_factor=4)\n", + "dataset.df = dataset.df[~dataset.df.bayer_path.str.contains('crw')]\n", + "dataset.df = dataset.df[~dataset.df.bayer_path.str.contains('dng_bayer')]\n", + "dataset.df = dataset.df[(dataset.df.iso >= iso_range[0]) & (dataset.df.iso <= iso_range[1])]\n", + "print(len(dataset.df ))\n", + "# Split dataset into train and val\n", + "val_size = int(len(dataset) * val_split)\n", + "train_size = len(dataset) - val_size\n", + "torch.manual_seed(42) # For reproducibility\n", + "train_dataset, val_dataset = random_split(dataset, [train_size, val_size])\n", + "# Set the validation dataset to use the same crops\n", + "val_dataset = copy.deepcopy(val_dataset)\n", + "val_dataset.dataset.validation = True\n", + "\n", + "train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)\n", + "val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True, num_workers=0)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6af0f3a2", + "metadata": {}, + "outputs": [], + "source": [ + "optimizer = torch.optim.Adam(model.parameters(), lr=lr)\n", + "if cosine_annealing:\n", + " sched = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, num_epochs,eta_min=lr*1e-6)\n", + "else:\n", + " sched = None\n", + " \n", + "vfe = VGGFeatureExtractor(config=((1, 64), (1, 128), (1, 256), (1, 512), (1, 512),), \n", + " feature_layers=[14], \n", + " activation=nn.ReLU\n", + " )\n", + "vfe = vfe.to(device)\n", + "\n", + "loss_fn = ShadowAwareLoss(\n", + " alpha=run_config['alpha'],\n", + " beta=run_config['beta'],\n", + " 
l1_weight=run_config['l1_weight'],\n", + " ssim_weight=run_config['ssim_weight'],\n", + " tv_weight=run_config['tv_weight'],\n", + " vgg_loss_weight=run_config['vgg_loss_weight'],\n", + " percept_loss_weight=run_config['percept_loss_weight'],\n", + " apply_gamma_fn=apply_gamma_torch,\n", + " vgg_feature_extractor=vfe,\n", + " device=device,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a4616a32", + "metadata": {}, + "outputs": [], + "source": [ + "from time import perf_counter\n", + "import time\n", + "from tqdm import tqdm\n", + "import torch\n", + "import torch.nn as nn\n", + "from src.training.utils import apply_gamma_torch\n", + "import mlflow\n", + "\n", + "def make_conditioning(conditioning, device):\n", + " B = conditioning.shape[0]\n", + " conditioning_extended = torch.zeros(B, 1).to(device)\n", + " conditioning_extended[:, 0] = conditioning[:, 0]\n", + " return conditioning_extended\n", + "\n", + "\n", + "from colour_demosaicing import (\n", + " mosaicing_CFA_Bayer,\n", + " demosaicing_CFA_Bayer_Menon2007\n", + ")\n", + "def visualize(idxs, _model, dataset, _device, _loss_func, rggb=False):\n", + " import matplotlib.pyplot as plt\n", + " _model.train()\n", + " total_loss, total_loss_malhevar, n_images, total_final_image_loss = 0.0, 0.0, 0, 0.0\n", + " start = perf_counter()\n", + "\n", + " for idx in idxs:\n", + " row = dataset[idx]\n", + " conditioning = row['conditioning'].unsqueeze(0).float().to(_device)\n", + " gt = row['ground_truth'].unsqueeze(0).float().to(_device)\n", + " sparse = row['cfa_sparse'].unsqueeze(0).float().to(_device)\n", + " rggb_values = row['cfa_rggb'].unsqueeze(0).float().to(_device)\n", + " bayer = nn.functional.pixel_shuffle(rggb_values, 2)\n", + "\n", + " debayered = debayer(bayer) \n", + " input = sparse\n", + " if rggb:\n", + " input = rggb_values\n", + "\n", + " conditioning = make_conditioning(conditioning, _device)\n", + " \n", + " with torch.no_grad():\n", + " with 
torch.autocast(device_type=\"mps\", dtype=torch.bfloat16):\n", + " pred = _model(input, conditioning)\n", + " # pred = debayered \n", + " loss = _loss_func(pred, gt)\n", + " loss_mal_he_var = _loss_func(debayered, gt)\n", + "\n", + " total_loss_malhevar += float(loss_mal_he_var.detach().cpu())\n", + " total_loss += float(loss.detach().cpu())\n", + " n_images += gt.shape[0]\n", + "\n", + " # Testing final image quality\n", + " final_image_loss = nn.functional.l1_loss(pred, gt)\n", + " total_final_image_loss += final_image_loss.item()\n", + "\n", + " plt.subplots(2, 3, figsize=(30, 15))\n", + "\n", + " plt.subplot(2, 3, 1)\n", + " plt.title('pred')\n", + " pred = apply_gamma_torch(pred[0].cpu().permute(1, 2, 0))\n", + " plt.imshow(pred)\n", + "\n", + " plt.subplot(2, 3, 2)\n", + " # plt.title('Menon')\n", + " # bayer = sparse[0].sum(axis=0)\n", + " # bayer = apply_gamma_torch(bayer).cpu().numpy()\n", + " # trad_demosaiced = demosaicing_CFA_Bayer_Menon2007(bayer)\n", + " # plt.imshow(trad_demosaiced)\n", + " plt.title('Mal He Var')\n", + " debayered_numpy = apply_gamma_torch(debayered)[0].permute(1, 2, 0).cpu().numpy()\n", + " plt.imshow(debayered_numpy)\n", + "\n", + " plt.subplot(2, 3, 3)\n", + " plt.title('gt')\n", + "\n", + " gt = apply_gamma_torch(gt[0].cpu().permute(1, 2, 0))\n", + " plt.imshow(gt)\n", + "\n", + " plt.subplot(2, 3, 4)\n", + " plt.imshow(pred - gt + 0.5)\n", + "\n", + "\n", + " plt.subplot(2, 3, 5)\n", + " plt.imshow(debayered_numpy - pred.cpu().numpy() + 0.5)\n", + "\n", + " plt.subplot(2, 3, 6)\n", + " plt.imshow(debayered_numpy - gt.cpu().numpy() + 0.5)\n", + " plt.show()\n", + " plt.clf()\n", + "\n", + " n_images = len(idxs)\n", + " print(\n", + " f\"Train loss: {total_loss/n_images:.6f} \"\n", + " f\"Train loss (MHV): {total_loss_malhevar/n_images:.6f} \"\n", + " f\"Final image val loss: {total_final_image_loss/n_images:.6f} \"\n", + " f\"Time: {perf_counter()-start:.1f}s \"\n", + " f\"Images: {n_images}\")\n", + "\n", + " return total_loss 
/ max(1, n_images), total_final_image_loss / max(1, n_images), perf_counter()-start" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b5b87533", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "subset_indices = np.array(val_dataset.indices) # indices in the original dataset\n", + "mask = val_dataset.dataset.df.iso.values[subset_indices] == 200\n", + "matching_indices_in_subset = np.nonzero(mask)[0]\n", + "matching_indices_in_subset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cde66163", + "metadata": {}, + "outputs": [], + "source": [ + "visualize(matching_indices_in_subset, model, val_dataset, device, loss_fn, rggb=rggb)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f092eecc", + "metadata": {}, + "outputs": [], + "source": [ + "from src.Restorer.debayering import Debayer5x5" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "35fd5ed9", + "metadata": {}, + "outputs": [], + "source": [ + "output = dataset[0]\n", + "\n", + "rggb_values = output['cfa_rggb'].unsqueeze(0).float()\n", + "bayer = nn.functional.pixel_shuffle(rggb_values, 2).to('mps')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bfa7d2e9", + "metadata": {}, + "outputs": [], + "source": [ + "im = debayer(bayer)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "12e6a557", + "metadata": {}, + "outputs": [], + "source": [ + "plt.imshow(im[0].cpu().permute(1, 2, 0))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e90dd490", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "OnSight", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": 
"ipython3", + "version": "3.11.11" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/1_visualize_alignment.ipynb b/1_visualize_alignment.ipynb new file mode 100644 index 0000000..04bbf73 --- /dev/null +++ b/1_visualize_alignment.ipynb @@ -0,0 +1,279 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "f6351e77", + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "import matplotlib.pyplot as plt\n", + "from torch.utils.data import DataLoader, random_split\n", + "import copy\n", + "from pathlib import Path\n", + "import torch" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2043dc8e", + "metadata": {}, + "outputs": [], + "source": [ + "from src.training.SmallRawDatasetNumpy import SmallRawDatasetNumpy\n", + "from src.training.load_config import load_config\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b14af214", + "metadata": {}, + "outputs": [], + "source": [ + "run_config = load_config()\n", + "dataset_path = Path(run_config['cropped_raw_subdir'])\n", + "align_csv = dataset_path / run_config['secondary_align_csv']\n", + "device=run_config['device']\n", + "\n", + "batch_size = run_config['batch_size']\n", + "lr = run_config['lr_base'] * batch_size\n", + "clipping = run_config['clipping']\n", + "\n", + "num_epochs = run_config['num_epochs_pretraining']\n", + "val_split = run_config['val_split']\n", + "crop_size = run_config['crop_size']\n", + "experiment = run_config['mlflow_experiment']\n", + "model_params = run_config['model_params']\n", + "rggb = model_params['rggb']" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "15f16fa7", + "metadata": {}, + "outputs": [], + "source": [ + "dataset = SmallRawDatasetNumpy(dataset_path, align_csv, crop_size=crop_size, validation=True)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "51ad2666", + "metadata": {}, + "outputs": [], + "source": [ + "# with 
mlflow.start_run(run_id=existing_run_id) as run:\n", + "# print(f\"Re-opened run: {run.info.run_name}\")\n", + " \n", + "# # Log your new validation metric\n", + "# mlflow.log_metric(\"final_validation_accuracy\", new_validation_metric)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ae01182c", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4aa24966", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "import pandas as pd\n", + "from scipy.signal import correlate2d\n", + "from scipy.ndimage import gaussian_filter\n", + "from skimage.metrics import structural_similarity as ssim\n", + "\n", + "# --- Metric functions ---\n", + "\n", + "def structure_score_autocorr(img_np):\n", + " img = img_np - img_np.mean()\n", + " corr = correlate2d(img, img, mode='full')\n", + " corr /= corr.max() + 1e-8\n", + " center = np.array(corr.shape) // 2\n", + " corr[center[0], center[1]] = 0\n", + " return float(np.mean(corr))\n", + "\n", + "def fourier_entropy(img_np):\n", + " f = np.fft.fftshift(np.fft.fft2(img_np))\n", + " mag = np.abs(f)\n", + " mag = mag / (mag.sum() + 1e-8)\n", + " return float(-np.sum(mag * np.log(mag + 1e-12)))\n", + "\n", + "def laplacian_variance(img_np):\n", + " import cv2\n", + " lap = cv2.Laplacian(img_np, cv2.CV_64F)\n", + " return float(lap.var())\n", + "\n", + "def self_ssim(img_np, sigma=1.0):\n", + " blurred = gaussian_filter(img_np, sigma)\n", + " return float(ssim(img_np, blurred, data_range=img_np.max() - img_np.min()))\n", + "\n", + "# --- Visualization and analysis function ---\n", + "\n", + "def analyze_dataset_structure(dataset, indices, out_dir=\"structure_analysis\"):\n", + " import os\n", + " os.makedirs(out_dir, exist_ok=True)\n", + " records = []\n", + "\n", + " for idx in indices:\n", + " row = dataset[idx]\n", + " noisy = row['noisy']\n", + " gt = 
row['aligned']\n", + "\n", + " # Ensure CPU numpy grayscale (if multi-channel)\n", + " def to_gray(x):\n", + " if isinstance(x, torch.Tensor):\n", + " x = x.detach().cpu().numpy()\n", + " if x.ndim == 3 and x.shape[0] in [1,3]:\n", + " x = np.moveaxis(x, 0, -1)\n", + " if x.ndim == 3:\n", + " x = np.mean(x, axis=-1)\n", + " return np.clip(x, 0, 1)\n", + "\n", + " noisy_np = to_gray(noisy)\n", + " gt_np = to_gray(gt)\n", + " diff_np = np.abs(noisy_np - gt_np)\n", + "\n", + " # Compute metrics\n", + " metrics = {\n", + " \"idx\": idx,\n", + " \"fourier_entropy_noisy\": fourier_entropy(noisy_np),\n", + " \"lap_var_noisy\": laplacian_variance(noisy_np),\n", + " \"self_ssim_noisy\": self_ssim(noisy_np),\n", + " \"autocorr_noisy\": structure_score_autocorr(noisy_np),\n", + " \"fourier_entropy_diff\": fourier_entropy(diff_np),\n", + " \"lap_var_diff\": laplacian_variance(diff_np),\n", + " }\n", + " records.append(metrics)\n", + "\n", + " # Plot\n", + " fig, axes = plt.subplots(1, 3, figsize=(12, 4))\n", + " for ax, im, title in zip(axes, [noisy_np, gt_np, diff_np],\n", + " ['Noisy', 'Aligned', 'Difference']):\n", + " ax.imshow(im, cmap='gray')\n", + " ax.set_title(title)\n", + " ax.axis('off')\n", + "\n", + " txt = \"\\n\".join([\n", + " f\"Entropy: {metrics['fourier_entropy_noisy']:.3f}\",\n", + " f\"LapVar: {metrics['lap_var_noisy']:.3f}\",\n", + " f\"SelfSSIM: {metrics['self_ssim_noisy']:.3f}\",\n", + " f\"AutoCorr: {metrics['autocorr_noisy']:.3f}\",\n", + " ])\n", + " plt.gcf().text(0.75, 0.5, txt, fontsize=10, va='center')\n", + " plt.tight_layout()\n", + " plt.savefig(f\"{out_dir}/example_{idx:04d}.png\", dpi=150)\n", + " plt.close(fig)\n", + "\n", + " df = pd.DataFrame.from_records(records)\n", + " df.to_csv(f\"{out_dir}/structure_metrics.csv\", index=False)\n", + " print(f\"Saved {len(df)} examples and metrics to {out_dir}\")\n", + " return df\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "56f5e484", + "metadata": {}, + "outputs": [], + 
"source": [ + "import torch\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "import pandas as pd\n", + "from scipy.signal import correlate2d\n", + "from scipy.ndimage import gaussian_filter\n", + "from skimage.metrics import structural_similarity as ssim\n", + "\n", + "def analyze_dataset_structure(dataset, indices, out_dir=\"structure_analysis\"):\n", + " import os\n", + " os.makedirs(out_dir, exist_ok=True)\n", + " records = []\n", + "\n", + " for idx in indices:\n", + " row = dataset[idx]\n", + " noisy = row['noisy'].permute(1, 2, 0)**(1/2.2)\n", + " gt = row['aligned'].permute(1, 2, 0)**(1/2.2)\n", + " diff = noisy-gt +0.5\n", + " # Plot\n", + " fig, axes = plt.subplots(1, 3, figsize=(12, 4))\n", + " for ax, im, title in zip(axes, [noisy, gt, diff],\n", + " ['Noisy', 'Aligned', 'Difference']):\n", + " ax.imshow(im)\n", + " ax.set_title(title)\n", + " ax.axis('off')\n", + "\n", + " plt.tight_layout()\n", + " plt.show()\n", + " plt.close(fig)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "59269994", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "mask = dataset.df.iso.values == 400\n", + "matching_indices_in_subset = np.nonzero(mask)[0]\n", + "matching_indices_in_subset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "54c87086", + "metadata": {}, + "outputs": [], + "source": [ + "analyze_dataset_structure(dataset, matching_indices_in_subset,out_dir=\"structure_analysis_no_align\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "422b379e", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "OnSight", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + 
"version": "3.11.11" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/2_finetune_model_raw.ipynb b/2_finetune_model_raw.ipynb new file mode 100644 index 0000000..9a7d529 --- /dev/null +++ b/2_finetune_model_raw.ipynb @@ -0,0 +1,258 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "f6351e77", + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "import matplotlib.pyplot as plt\n", + "from torch.utils.data import DataLoader, random_split\n", + "import torch.nn as nn\n", + "import torch\n", + "import copy\n", + "import mlflow\n", + "import mlflow.pytorch\n", + "from pathlib import Path" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2043dc8e", + "metadata": {}, + "outputs": [], + "source": [ + "from src.training.RawDatasetDNG import RawDatasetDNG\n", + "from src.training.losses.ShadowAwareLoss import ShadowAwareLoss\n", + "from src.training.VGGFeatureExtractor import VGGFeatureExtractor\n", + "from src.training.train_loop import train_one_epoch, visualize\n", + "from src.training.utils import apply_gamma_torch\n", + "from src.training.load_config import load_config\n", + "from src.Restorer.Cond_NAF import make_full_model_RGGB\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a0464c98", + "metadata": {}, + "outputs": [], + "source": [ + "run_config = load_config()\n", + "dataset_path = Path(run_config['cropped_raw_subdir'])\n", + "align_csv = dataset_path / run_config['secondary_align_csv']" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7322456f", + "metadata": {}, + "outputs": [], + "source": [ + "dataset_path" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ba20b866", + "metadata": {}, + "outputs": [], + "source": [ + "device=run_config['device']\n", + "\n", + "batch_size = run_config['batch_size']\n", + "lr = run_config['lr_base'] * batch_size\n", + "clipping = run_config['clipping']\n", + "\n", + 
"num_epochs = run_config['num_epochs_finetuning']\n", + "cosine_annealing = run_config['cosine_annealing']\n", + "\n", + "val_split = run_config['val_split']\n", + "crop_size = run_config['crop_size']\n", + "experiment = run_config['mlflow_experiment']\n", + "mlflow_path = run_config['mlflow_path']\n", + "colorspace = run_config['colorspace']\n", + "iso_range = run_config['iso_range']\n", + "\n", + "rggb = True\n", + "mlflow.set_tracking_uri(f\"file://{mlflow_path}\")\n", + "mlflow.set_experiment(experiment)\n", + "\n", + "params = {**run_config}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e9a26124", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "RUN_ID = \"7d9ffb05e2c747fe93647e06ef43e51b\" \n", + "ARTIFACT_PATH = run_config['run_path']\n", + "params['base_RUN_ID'] = RUN_ID\n", + "params['base_ARTIFACT_PATH'] = ARTIFACT_PATH\n", + "\n", + "model_uri = f\"runs:/{RUN_ID}/{ARTIFACT_PATH}\"\n", + "\n", + "try:\n", + " model = mlflow.pytorch.load_model(model_uri)\n", + " model.eval()\n", + " print(f\"Model successfully loaded from MLflow URI: {model_uri}\")\n", + " \n", + "\n", + "except Exception as e:\n", + " print(f\"Error loading model from MLflow: {e}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "15f16fa7", + "metadata": {}, + "outputs": [], + "source": [ + "dataset = RawDatasetDNG(dataset_path, align_csv, colorspace, crop_size=crop_size)\n", + "dataset.df = dataset.df[~dataset.df.bayer_path.str.contains('crw')]\n", + "dataset.df = dataset.df[~dataset.df.bayer_path.str.contains('dng_bayer')]\n", + "dataset.df = dataset.df[(dataset.df.iso >= iso_range[0]) & (dataset.df.iso <= iso_range[1])]\n", + "\n", + "# Split dataset into train and val\n", + "val_size = int(len(dataset) * val_split)\n", + "train_size = len(dataset) - val_size\n", + "torch.manual_seed(42) # For reproducibility\n", + "train_dataset, val_dataset = random_split(dataset, [train_size, val_size])\n", + "# Set the validation 
dataset to use the same crops\n", + "val_dataset = copy.deepcopy(val_dataset)\n", + "val_dataset.dataset.validation = True\n", + "\n", + "train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)\n", + "val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True, num_workers=0)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6af0f3a2", + "metadata": {}, + "outputs": [], + "source": [ + "optimizer = torch.optim.Adam(model.parameters(), lr=lr)\n", + "if cosine_annealing:\n", + " sched = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, num_epochs,eta_min=lr*1e-6)\n", + "else:\n", + " sched = None\n", + " \n", + "vfe = VGGFeatureExtractor(config=((1, 64), (1, 128), (1, 256), (1, 512), (1, 512),), \n", + " feature_layers=[14], \n", + " activation=nn.ReLU\n", + " )\n", + "vfe = vfe.to(device)\n", + "\n", + "loss_fn = ShadowAwareLoss(\n", + " alpha=run_config['alpha'],\n", + " beta=run_config['beta'],\n", + " l1_weight=run_config['l1_weight'],\n", + " ssim_weight=run_config['ssim_weight'],\n", + " tv_weight=run_config['tv_weight'],\n", + " vgg_loss_weight=run_config['vgg_loss_weight'],\n", + " percept_loss_weight=run_config['percept_loss_weight'],\n", + " apply_gamma_fn=apply_gamma_torch,\n", + " vgg_feature_extractor=vfe,\n", + " device=device,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e7b3706e", + "metadata": {}, + "outputs": [], + "source": [ + "optimizer = torch.optim.Adam(model.parameters(), lr=lr)\n", + "if cosine_annealing:\n", + " sched = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, num_epochs,eta_min=lr*1e-6)\n", + "else:\n", + " sched = None\n", + " \n", + "vfe = VGGFeatureExtractor(config=((1, 64), (1, 128), (1, 256), (1, 512), (1, 512),), \n", + " feature_layers=[14], \n", + " activation=nn.ReLU\n", + " )\n", + "vfe = vfe.to(device)\n", + "\n", + "loss_fn = ShadowAwareLoss(\n", + " alpha=run_config['alpha'],\n", + " 
beta=run_config['beta'],\n", + " l1_weight=run_config['l1_weight'],\n", + " ssim_weight=run_config['ssim_weight'],\n", + " tv_weight=run_config['tv_weight'],\n", + " vgg_loss_weight=run_config['vgg_loss_weight'],\n", + " percept_loss_weight=run_config['percept_loss_weight'],\n", + " apply_gamma_fn=apply_gamma_torch,\n", + " vgg_feature_extractor=vfe,\n", + " device=device,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8bc86a06", + "metadata": {}, + "outputs": [], + "source": [ + "with mlflow.start_run(run_name=run_config['run_name']) as run:\n", + " mlflow.log_params(params)\n", + " for epoch in range(num_epochs):\n", + " train_one_epoch(epoch, model, optimizer, train_loader, device, loss_fn, clipping, \n", + " log_interval = 10, sleep=0.0, rggb=rggb, max_batches=0)\n", + " if cosine_annealing:\n", + " sched.step()\n", + " \n", + " mlflow.pytorch.log_model(\n", + " pytorch_model=model,\n", + " name=run_config['run_path'],\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5a775d31", + "metadata": {}, + "outputs": [], + "source": [ + "run.info.run_id" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "OnSight", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.11" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/3_make_script.ipynb b/3_make_script.ipynb new file mode 100644 index 0000000..171ad30 --- /dev/null +++ b/3_make_script.ipynb @@ -0,0 +1,130 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "685f1f4b", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "from pathlib import Path\n", + "import mlflow" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": 
"a4c87e77", + "metadata": {}, + "outputs": [], + "source": [ + "from src.training.load_config import load_config" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5f499abe", + "metadata": {}, + "outputs": [], + "source": [ + "run_config = load_config()\n", + "output = Path(run_config['script_path']) / f\"{run_config['run_path']}.pt\"\n", + "experiment = run_config['mlflow_experiment']\n", + "mlflow_path = run_config['mlflow_path']\n", + "rggb = True\n", + "mlflow.set_tracking_uri(f\"file://{mlflow_path}\")\n", + "mlflow.set_experiment(experiment)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0655be14", + "metadata": {}, + "outputs": [], + "source": [ + "output" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f5eb1679", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "RUN_ID = \"10df0d1d6eba4c6b887086d69ab390a7\" \n", + "ARTIFACT_PATH = run_config['run_path']\n", + "\n", + "model_uri = f\"runs:/{RUN_ID}/{ARTIFACT_PATH}\"\n", + "\n", + "try:\n", + " model = mlflow.pytorch.load_model(model_uri)\n", + " model.eval()\n", + " model = model.to('cpu')\n", + " print(f\"Model successfully loaded from MLflow URI: {model_uri}\")\n", + " \n", + "\n", + "except Exception as e:\n", + " print(f\"Error loading model from MLflow: {e}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "913f088c", + "metadata": {}, + "outputs": [], + "source": [ + "run_config['crop_size']" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d1e4397b", + "metadata": {}, + "outputs": [], + "source": [ + "in_channels = run_config['model_params']['in_channels']\n", + "out_channels = run_config['model_params']['out_channels']\n", + "cond_input = run_config['model_params']['cond_input']\n", + "crop_size = run_config['crop_size']\n", + "\n", + "if run_config['model_params']['rggb']:\n", + " input = torch.rand(1, in_channels, crop_size//2, crop_size//2)\n", + "else:\n", + " input = 
torch.rand(1, in_channels, crop_size, crop_size)\n", + "\n", + "cond = torch.rand(1, cond_input)\n", + "residual = torch.rand(1, out_channels, crop_size, crop_size)\n", + "traced_script_module = torch.jit.trace(model, (input, cond, residual))\n", + "traced_script_module.save(output)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "OnSight", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.11" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/README.md b/README.md index 21cf958..db6bb96 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,122 @@ -# WIP +# Restorer -This is a WIP repo for testing conditioning a unet. More instructions to follow. \ No newline at end of file +## Overview + +This branch contains experimental code for a lightweight, local training workflow to aid in the development of models for **Raw Refinery**. The goal is to enable faster model iteration on low-powered local computers with limited SSD space when access to large GPUs and SSDs are limited. + +While functional, it remains a **work in progress**, primarily intended as a proof-of-concept for faster local experimentation. + +Because the models are designed to be computationally efficient, the main bottleneck in this workflow is **disk I/O**. Raw image files are high-resolution 12-bit images, too large for the laptop’s internal SSD but too slow to stream from an external HDD. To mitigate this, we generate a smaller, compressed dataset suitable for quick pretraining. + +--- + +## Producing the Smaller Dataset + +We pretrain using a compressed **8-bit version** of the noisy raw sensor data and their corresponding **preprocessed (demosaiced)** ground-truth images. 
+ +After testing various formats, **JPEG** proved to be the most convenient and performant: + +* Faster to read than PNG or NumPy-based formats +* Smaller file sizes even at high quality settings +* Sufficiently preserves high-frequency detail in the Bayer sensor data + +*(A possible future improvement would be to store each of the four Bayer channels as separate JPEGs and merge them at runtime, but this implementation opts for simplicity.)* + +### Relevant Files + +* **`download_rawnind.sh`** + Downloads the [RawNIND dataset](https://doi.org/10.14428/DVN/DEQCIM) file-by-file, as it’s too large to fetch directly. You can place it in any directory where you wish the dataset to be stored. + +* **`config.yaml`** + Contains global configuration parameters including file paths, training hyperparameters, and general workflow settings. + +* **`0_produce_small_dataset.ipynb`** + Two-part notebook: + + 1. Iterates through raw images, aligns a demosaiced ground-truth image to each noisy image, and saves aligned pairs as JPEGs. + + * Alignment first uses a **feature-based method** for coarse alignment. + * Followed by an **ECC alignment** for sub-pixel precision. + * The **correlation coefficient** is saved to help identify poorly aligned images. + * Manual inspection is still recommended—alignment quality is critical for effective model training. + 2. Optionally crops the JPEGs to reduce storage requirements and speed up data loading. + +* **`0_align_images.ipynb`** + Re-runs image alignment. Useful for experimenting with alternative alignment techniques or metrics. Not required if the previous notebook has already aligned the dataset. + +--- + +## Pretraining + +With the small dataset prepared, we can begin **local pretraining**. + +Training hyperparameters are stored in `config.yaml`. 
Typical settings: + +* **Patch size:** 256×256 +* **Batch size:** 2 (for low-memory environments) +* **Optimizer:** Adam +* **Learning rate:** constant (can be scaled for larger GPUs) +* **Epochs:** 75 (sufficient for baseline performance); 150 preferred (~<24 hrs training time) + +MLflow is used to track runs, hyperparameters, and metrics. + +The dataset is split **80/20** into training and validation subsets. +Regularization methods like L2 weight decay, dropout, or strong augmentations often **harm** reconstruction performance; **random cropping** is typically sufficient. + +### Relevant Files + +* **`1_pretrain_model.ipynb`** + Main training notebook. Loads the model, initializes the optimizer, and runs the training loop. + MLflow integration records hyperparameters and results. + *(Validation is not yet integrated directly in the training loop.)* + **To-do:** Add model checkpointing. + +* **`1_validate_model.ipynb`** + Validates trained models. + Separated from the main loop to allow for manual visual inspection of images with different noise levels, which is often more informative than numerical loss metrics. + **To-do:** Save validation artifacts (sample images, metrics) to the MLflow run. + +* **`src/Restorer/Cond_NAF.py`** + Defines the neural network model. + +* **`src/training/train_loop.py`** + Implements the core training and validation loops. + +* **`src/training/losses/ShadowAwareLoss.py`** + Custom loss combining: + + * L1 loss + * Multi-scale SSIM loss + * Perceptual loss (VGG-like) + Designed to emphasize realistic and visually appealing reconstructions. + +--- + +## Fine-Tuning + +Once pretraining converges, we fine-tune the model on **real raw images**. +This step restores high-frequency detail lost during 8-bit compression. + +Fine-tuning uses similar hyperparameters to pretraining, but introduces **Cosine annealing** for smoother learning rate decay. + +This portion of the workflow is still in progress; its documentation will be added shortly.
+ +--- + +## Deployment + +To deploy a trained model, we **trace** it into a TorchScript module for seamless integration with the **Raw Refinery** application. + +This is handled in the **`3_make_script.ipynb`** notebook. + +--- + +## Acknowledgments + +> Brummer, Benoit; De Vleeschouwer, Christophe. (2025). +> *Raw Natural Image Noise Dataset.* +> [https://doi.org/10.14428/DVN/DEQCIM](https://doi.org/10.14428/DVN/DEQCIM), Open Data @ UCLouvain, V1. + +> Chen, Liangyu; Chu, Xiaojie; Zhang, Xiangyu; Chen, Jianhao. (2022). +> *NAFNet: Simple Baselines for Image Restoration.* +> [https://doi.org/10.48550/arXiv.2208.04677](https://doi.org/10.48550/arXiv.2208.04677), arXiv, V1. \ No newline at end of file diff --git a/config.yaml b/config.yaml new file mode 100644 index 0000000..f85074d --- /dev/null +++ b/config.yaml @@ -0,0 +1,59 @@ +# config.yaml +# --- Paths --- +base_data_dir: /Volumes/EasyStore/RAWNIND/ +jpeg_output_subdir: /Volumes/EasyStore/RAWNIND/JPEGs/Cropped_JPEG +cropped_jpeg_subdir: /Volumes/EasyStore/RAWNIND/JPEGs/Cropped_JPEG +cropped_raw_subdir: /Users/ryanmueller/Develop/Cropped_Raw/ +cropped_raw_size: 2000 +align_csv: align_data.csv +secondary_align_csv: align_phase_v2_exposure_corr.csv +script_path: /Volumes/EasyStore/models/traces +mlflow_path: /Volumes/EasyStore/models/mlfow + +# --- Training Params --- +colorspace: lin_rec2020 +device: mps +batch_size: 2 +crop_size: 256 +lr_base: 2.5e-5 +clipping: 1e-2 +num_epochs_pretraining: 75 +num_epochs_finetuning: 20 +val_split: 0.2 +random_seed: 42 +cosine_annealing: False +iso_range: [0, 999999] + +# --- Experiment Settings --- +experiment: NAF_test +mlflow_experiment: NAFNet_variations + +# --- Run Configuration ---: +run_name: NAF_corr_DIST +run_path: NAF_deep_test_align +model_params: + chans: [32, 64, 128, 256, 256, 256] + enc_blk_nums: [2, 2, 2, 3, 4] + middle_blk_num: 12 + dec_blk_nums: [2, 2, 2, 2, 2] + cond_input: 1 + in_channels: 4 + out_channels: 3 + rggb: True + use_CondFuserV2: False + use_add: 
False + use_CondFuserV3: False + use_attnblock: False + use_CondFuserV4: False + use_NAFBlock0_learned_norm: False + use_input_stats: False + gamma: 2.2 + +# --- Loss Configuration ---: +alpha: 0.2 +beta: 5.0 +l1_weight: 0.16 +ssim_weight: 0.84 +tv_weight: 0.0 +vgg_loss_weight: 0.0 +percept_loss_weight: 0.025 \ No newline at end of file diff --git a/config_demosaicing.yaml b/config_demosaicing.yaml new file mode 100644 index 0000000..bb0f980 --- /dev/null +++ b/config_demosaicing.yaml @@ -0,0 +1,59 @@ +# config_demosaicing.yaml +# --- Paths --- +# base_data_dir: /Volumes/EasyStore/RAWNIND/ +# jpeg_output_subdir: /Volumes/EasyStore/RAWNIND/JPEGs/Cropped_JPEG +# cropped_jpeg_subdir: /Volumes/EasyStore/RAWNIND/JPEGs/Cropped_JPEG +cropped_raw_subdir: /Users/ryanmueller/Develop/Cropped_Raw/ +cropped_raw_size: 2000 +align_csv: align_data.csv +secondary_align_csv: align_phase_v2.csv +script_path: /Users/ryanmueller/Develop/traces/ +mlflow_path: /Users/ryanmueller/Develop/mlflow/ + +# --- Training Params --- +colorspace: lin_rec2020 +device: mps +batch_size: 2 +crop_size: 256 +lr_base: 2.5e-5 +clipping: 1e-2 +num_epochs_pretraining: 5 +num_epochs_finetuning: 20 +val_split: 0.2 +random_seed: 42 +cosine_annealing: False +iso_range: [0, 999999] + +# --- Experiment Settings --- +experiment: Demosaicing +mlflow_experiment: Demosaicing + +# --- Run Configuration ---: +run_name: Demosaicing_test_5x5_just_l1 +run_path: Demosaicing_Test +model_params: + chans: [12, 24] + enc_blk_nums: [1] + middle_blk_num: 1 + dec_blk_nums: [1] + cond_input: 1 + in_channels: 4 + out_channels: 3 + rggb: True + use_CondFuserV2: False + use_add: False + use_CondFuserV3: False + use_attnblock: False + use_CondFuserV4: False + use_NAFBlock0_learned_norm: False + use_input_stats: False + gamma: 1.0 + +# --- Loss Configuration ---: +alpha: 1.0 +beta: 5.0 +l1_weight: 1 +ssim_weight: 0 +tv_weight: 0.0 +vgg_loss_weight: 0.0 +percept_loss_weight: 0.0 \ No newline at end of file diff --git a/download_rawnind.sh
#!/bin/bash
#
# Download every file of the RawNIND dataset from the UCLouvain Dataverse,
# one file at a time, into the current directory.
#
# The script is safe to re-run: files that are already complete are skipped,
# and interrupted downloads are resumed from their "<name>.part" file.
# Requires: curl, jq.

# Dataverse base URL
BASE_URL="https://dataverse.uclouvain.be"

# Dataset DOI
DOI="doi:10.14428/DVN/DEQCIM"

# Metadata file
METADATA_FILE="dataset.json"

# Fetch dataset metadata if not already downloaded
if [ ! -f "$METADATA_FILE" ]; then
    echo "Fetching dataset metadata..."
    # --fail plus a temporary name: an HTTP error page must never be cached
    # as if it were valid metadata, otherwise every later run would reuse it.
    if ! curl -sS --fail "$BASE_URL/api/datasets/:persistentId?persistentId=$DOI" -o "$METADATA_FILE.part"; then
        echo "Error fetching dataset metadata" >&2
        rm -f "$METADATA_FILE.part"
        exit 1
    fi
    mv "$METADATA_FILE.part" "$METADATA_FILE"
else
    echo "Using cached metadata file: $METADATA_FILE"
fi

# Parse and download each file
echo "Starting download..."
# IFS= and -r keep each JSON record byte-for-byte; a plain `read` would strip
# the backslash escapes jq emits and corrupt the record.
jq -c '.data.latestVersion.files[]' "$METADATA_FILE" | while IFS= read -r file; do
    FILE_ID=$(printf '%s\n' "$file" | jq '.dataFile.id')
    FILENAME=$(printf '%s\n' "$file" | jq -r '.dataFile.filename')

    # Skip file if it already exists
    if [ -f "$FILENAME" ]; then
        echo "Skipping existing file: $FILENAME"
        continue
    fi

    echo "Downloading: $FILENAME (ID: $FILE_ID)"
    # Download under a ".part" name and rename only on success. A failed or
    # interrupted transfer therefore never looks like a finished file, and
    # `-C -` can resume it on the next run.
    PARTIAL="$FILENAME.part"
    if curl -L --fail -C - -o "$PARTIAL" "$BASE_URL/api/access/datafile/$FILE_ID"; then
        mv "$PARTIAL" "$FILENAME"
        echo "Finished: $FILENAME"
    else
        echo "Error downloading: $FILENAME — will retry next time"
    fi
done

echo "All files processed."
class CHASPABlock(nn.Module):
    """Residual block: conditioned channel mixing followed by spatial mixing.

    The block consumes and returns a ``(features, conditioning)`` tuple so that
    several blocks can be chained inside an ``nn.Sequential``.

    ``beta`` and ``gamma`` are initialised to zero, so a freshly constructed
    block returns its input features unchanged.

    Args:
        c: Number of input/output feature channels.
        DW_Expand: Kept for signature compatibility. It used to size the
            channel attention, which only worked when it equalled
            ``FFN_Expand``; the attention is now sized from ``FFN_Expand``.
        FFN_Expand: Expansion factor of the 1x1 channel-mixing convolution.
            The ``SimpleGate`` halves the expanded width again.
        drop_out_rate: Dropout probability; ``nn.Identity`` is used when 0.
        cond_chans: Length of the conditioning vector fed to the attention.
    """

    def __init__(self, c, DW_Expand=2, FFN_Expand=2, drop_out_rate=0.0, cond_chans=0):
        super().__init__()
        ffn_channel = FFN_Expand * c

        self.NKA = NKA(c)
        # Depthwise 3x3 convolution used in the spatial-mixing half.
        self.conv1 = nn.Conv2d(
            in_channels=c,
            out_channels=c,
            kernel_size=3,
            padding=1,
            stride=1,
            groups=c,
            bias=True,
        )

        # Conditioned channel attention. It is applied to the gated output of
        # conv2, which has ffn_channel // 2 channels, so it must be sized from
        # FFN_Expand (identical to the old DW_Expand sizing for the defaults).
        self.sca = ConditionedChannelAttention(ffn_channel // 2, cond_chans)

        # SimpleGate
        self.sg = SimpleGate()

        self.conv2 = nn.Conv2d(
            in_channels=c,
            out_channels=ffn_channel,
            kernel_size=1,
            padding=0,
            stride=1,
            groups=1,
            bias=True,
        )
        self.conv3 = nn.Conv2d(
            in_channels=ffn_channel // 2,
            out_channels=c,
            kernel_size=1,
            padding=0,
            stride=1,
            groups=1,
            bias=True,
        )

        self.norm1 = nn.GroupNorm(1, c)
        self.norm2 = nn.GroupNorm(1, c)

        self.dropout1 = (
            nn.Dropout(drop_out_rate) if drop_out_rate > 0.0 else nn.Identity()
        )
        self.dropout2 = (
            nn.Dropout(drop_out_rate) if drop_out_rate > 0.0 else nn.Identity()
        )

        # Learnable per-channel scales of the two residual branches.
        self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)
        self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)

    def forward(self, input):
        """Apply the block to ``input = (features, conditioning)``.

        Returns:
            ``(new_features, conditioning)``; the conditioning is passed
            through untouched for the next block.
        """
        inp = input[0]
        cond = input[1]

        # Channel mixing: norm -> 1x1 expand -> gate -> conditioned attention
        # -> 1x1 project, added back through the beta-scaled residual.
        x = self.norm1(inp)
        x = self.conv2(x)
        x = self.sg(x)
        x = x * self.sca(x, cond)
        x = self.conv3(x)
        x = self.dropout2(x)
        y = inp + x * self.beta

        # Spatial mixing: norm -> NKA attention -> depthwise 3x3, added back
        # through the gamma-scaled residual.
        x = self.NKA(self.norm2(y))
        x = self.conv1(x)
        x = self.dropout1(x)

        return (y + x * self.gamma, cond)
class Restorer(nn.Module):
    """Conditioned encoder/decoder restoration network built from CHASPABlocks.

    A small MLP turns the raw conditioning input into a ``cond_output``-wide
    vector that every block and every skip-connection fuser receives.

    Args:
        in_channels: Channels of the input image.
        out_channels: Channels of the produced image.
        middle_blk_num: Number of blocks at the bottleneck.
        enc_blk_nums: Number of blocks per encoder stage.
        dec_blk_nums: Number of blocks per decoder stage.
        chans: Channel width per resolution level; ``chans[0]`` is the width
            after the intro convolution. The decoder indexes it from the end,
            so it is expected to hold one more entry than there are stages.
        cond_input: Length of the raw conditioning input.
        cond_output: Length of the conditioning vector given to the blocks.
        expand_dims: Stored on the instance; not otherwise used in this class.
        drop_out_rate: Dropout probability for the conditioning MLP and the
            first encoder stage.
        drop_out_rate_increment: Added to the dropout rate after each encoder
            stage and subtracted again before each decoder stage.
        rggb: When true, the intro expands to ``4 * width`` channels and a
            ``PixelShuffle(2)`` doubles the spatial resolution, so the output
            is twice as high and wide as the input.
    """

    def __init__(
        self,
        in_channels=3,
        out_channels=3,
        middle_blk_num=1,
        enc_blk_nums=(),
        dec_blk_nums=(),
        chans=(),
        cond_input=1,
        cond_output=32,
        expand_dims=2,
        drop_out_rate=0.0,
        drop_out_rate_increment=0.0,
        rggb=False,
    ):
        super().__init__()
        width = chans[0]

        self.expand_dims = expand_dims
        self.conditioning_gen = nn.Sequential(
            nn.Linear(cond_input, 64),
            nn.ReLU(),
            nn.Dropout(drop_out_rate),
            nn.Linear(64, cond_output),
        )
        self.rggb = rggb
        if not rggb:
            self.intro = nn.Conv2d(
                in_channels=in_channels,
                out_channels=width,
                kernel_size=3,
                padding=1,
                stride=1,
                groups=1,
                bias=True,
            )
        else:
            self.intro = nn.Sequential(
                nn.Conv2d(
                    in_channels=in_channels,
                    out_channels=width * 2 ** 2,
                    kernel_size=3,
                    padding=1,
                    stride=1,
                    groups=1,
                    bias=True,
                ),
                nn.PixelShuffle(2),
            )

        self.ending = nn.Conv2d(
            in_channels=width,
            out_channels=out_channels,
            kernel_size=3,
            padding=1,
            stride=1,
            groups=1,
            bias=True,
        )

        # Register every container up front. `middle_blks` is replaced by an
        # nn.Sequential below; the placeholder keeps its position in the
        # module registration order, and therefore in `parameters()`.
        self.encoders = nn.ModuleList()
        self.decoders = nn.ModuleList()
        self.middle_blks = nn.ModuleList()
        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()
        self.merges = nn.ModuleList()

        for i, num in enumerate(enc_blk_nums):
            current_chan = chans[i]
            next_chan = chans[i + 1]
            self.encoders.append(
                nn.Sequential(
                    *[
                        CHASPABlock(current_chan, cond_chans=cond_output, drop_out_rate=drop_out_rate)
                        for _ in range(num)
                    ]
                )
            )
            drop_out_rate += drop_out_rate_increment
            # Strided 2x2 convolution halves the resolution.
            self.downs.append(nn.Conv2d(current_chan, next_chan, 2, 2))

        # Width after the last encoder stage. Indexing `chans` directly avoids
        # depending on a loop variable that does not exist when there are no
        # encoder stages.
        middle_chan = chans[len(enc_blk_nums)]
        self.middle_blks = nn.Sequential(
            *[
                CHASPABlock(middle_chan, cond_chans=cond_output, drop_out_rate=drop_out_rate)
                for _ in range(middle_blk_num)
            ]
        )

        for i, num in enumerate(dec_blk_nums):
            current_chan = chans[-i - 1]
            next_chan = chans[-i - 2]
            # 1x1 convolution to 4x the target width, then PixelShuffle(2)
            # doubles the resolution.
            self.ups.append(
                nn.Sequential(
                    nn.Conv2d(current_chan, next_chan * 2 ** 2, 1, bias=False),
                    nn.PixelShuffle(2),
                )
            )
            drop_out_rate -= drop_out_rate_increment
            self.merges.append(CondFuser(next_chan, cond_chan=cond_output))
            self.decoders.append(
                nn.Sequential(
                    *[
                        CHASPABlock(next_chan, cond_chans=cond_output, drop_out_rate=drop_out_rate)
                        for _ in range(num)
                    ]
                )
            )

        # Inputs are padded to a multiple of this so every downsampling step
        # divides evenly.
        self.padder_size = 2 ** len(self.encoders)

    def forward(self, inp, cond_in):
        """Restore ``inp`` under the conditioning ``cond_in``.

        Returns:
            The network output cropped back to the unpadded size (twice the
            input height and width in ``rggb`` mode).
        """
        cond = self.conditioning_gen(cond_in)

        _, _, H, W = inp.shape
        if self.rggb:
            # The intro's PixelShuffle doubles the resolution.
            H = 2 * H
            W = 2 * W
        inp = self.check_image_size(inp)

        x = self.intro(inp)

        encs = []
        for encoder, down in zip(self.encoders, self.downs):
            x = encoder((x, cond))[0]
            encs.append(x)
            x = down(x)

        x = self.middle_blks((x, cond))[0]

        # Walk the skip connections from the deepest level back to the top.
        for decoder, up, merge, enc_skip in zip(self.decoders, self.ups, self.merges, encs[::-1]):
            x = up(x)
            x = merge(x, enc_skip, cond)
            x = decoder((x, cond))[0]

        x = self.ending(x)
        return x[:, :, :H, :W]

    def check_image_size(self, x):
        """Pad ``x`` on the right and bottom to a multiple of ``padder_size``."""
        _, _, h, w = x.size()
        mod_pad_h = (self.padder_size - h % self.padder_size) % self.padder_size
        mod_pad_w = (self.padder_size - w % self.padder_size) % self.padder_size
        return F.pad(x, (0, mod_pad_w, 0, mod_pad_h))
class LayerNorm2dAdjusted(nn.Module):
    """Channel-wise layer norm whose output is re-scaled to target statistics.

    Each spatial position is normalised across the channel axis (dim 1),
    then mapped to ``target_var`` / ``target_mu`` before the learnable
    per-channel affine transform is applied.

    Args:
        channels: Number of channels of the inputs; sizes ``weight``/``bias``.
        eps: Added to variances before taking the square root.
    """

    def __init__(self, channels, eps=1e-6):
        # Zero-argument super(): the previous call named a different class
        # (LayerNorm2d), which raises TypeError when this class is built.
        super().__init__()
        self.register_parameter("weight", nn.Parameter(torch.ones(channels)))
        self.register_parameter("bias", nn.Parameter(torch.zeros(channels)))
        self.eps = eps

    def forward(self, x, target_mu, target_var):
        """Normalise ``x`` over channels and re-target its statistics.

        Args:
            x: Tensor with the channel axis at dim 1 and two trailing spatial
                dims (the affine parameters are viewed as ``(1, C, 1, 1)``).
            target_mu: Mean to impose; must broadcast against ``x``.
            target_var: Variance to impose; must broadcast against ``x``.

        Returns:
            Tensor of the same shape as ``x``.
        """
        mu = x.mean(1, keepdim=True)
        var = (x - mu).pow(2).mean(1, keepdim=True)

        # Standardise, then scale/shift to the requested statistics.
        y = (x - mu) / torch.sqrt(var + self.eps)
        y = y * torch.sqrt(target_var + self.eps) + target_mu

        weight_view = self.weight.view(1, self.weight.size(0), 1, 1)
        bias_view = self.bias.view(1, self.bias.size(0), 1, 1)

        y = weight_view * y + bias_view
        return y
class ConditionedChannelAttention(nn.Module):
    """Per-channel attention weights computed from pooled features and a
    conditioning vector.

    The feature map is average-pooled to one value per channel, concatenated
    with the conditioning vector and passed through a single linear layer.

    Args:
        dims: Number of feature channels (also the number of outputs).
        cat_dims: Length of the conditioning vector.
    """

    def __init__(self, dims, cat_dims):
        super().__init__()
        self.mlp = nn.Sequential(nn.Linear(dims + cat_dims, dims))
        self.pool = nn.AdaptiveAvgPool2d(1)

    def forward(self, x, conditioning):
        """Return attention weights with one value per channel of ``x``.

        The result keeps two trailing singleton spatial dims so it can be
        multiplied onto ``x`` by broadcasting.
        """
        pooled = self.pool(x)
        # Give the conditioning two trailing singleton dims to match `pooled`.
        cond_map = conditioning[..., None, None]
        stacked = torch.cat((pooled, cond_map), dim=1)
        # nn.Linear acts on the last axis, so move channels there and back.
        channels_last = stacked.permute(0, 2, 3, 1)
        return self.mlp(channels_last).permute(0, 3, 1, 2)
class CondFuserV2(nn.Module):
    """Fuse two feature maps with conditioned channel attention.

    The inputs are concatenated on the channel axis, re-weighted by the
    conditioned channel attention, split back into two halves and multiplied
    element-wise.

    Args:
        chan: Channels of each of the two inputs.
        cond_chan: Length of the conditioning vector.
    """

    def __init__(self, chan, cond_chan=1):
        super().__init__()
        self.cca = ConditionedChannelAttention(chan * 2, cond_chan)
        # Kept so checkpoints that contain `spa.*` weights still load.
        self.spa = NKA(chan * 2)

    def forward(self, x1, x2, cond):
        """Return the fused feature map (same shape as ``x1``)."""
        x = torch.cat([x1, x2], dim=1)
        x = self.cca(x, cond) * x
        # NOTE(review): the previous code also evaluated
        # `torch.sigmoid(self.spa(x)) * x` here but never used the result, so
        # the spatial branch had no effect on the output. The dead computation
        # is dropped; confirm whether the spatial attention was meant to be
        # applied before wiring it in, as that would change model outputs.
        x1, x2 = x.chunk(2, dim=1)
        return x1 * x2
in_channels=dw_channel, + out_channels=dw_channel, + kernel_size=3, + padding=1, + stride=1, + groups=dw_channel, + bias=True, + ) + self.conv3 = nn.Conv2d( + in_channels=dw_channel // 2, + out_channels=c, + kernel_size=1, + padding=0, + stride=1, + groups=1, + bias=True, + ) + + # Simplified Channel Attention + self.sca = ConditionedChannelAttention(dw_channel // 2, cond_chans) + + # SimpleGate + self.sg = SimpleGate() + + ffn_channel = FFN_Expand * c + self.conv4 = nn.Conv2d( + in_channels=c, + out_channels=ffn_channel, + kernel_size=1, + padding=0, + stride=1, + groups=1, + bias=True, + ) + self.conv5 = nn.Conv2d( + in_channels=ffn_channel // 2, + out_channels=c, + kernel_size=1, + padding=0, + stride=1, + groups=1, + bias=True, + ) + + # self.grn = GRN(ffn_channel // 2) + + self.norm1 = LayerNorm2d(c) + self.norm2 = LayerNorm2d(c) + + self.dropout1 = ( + nn.Dropout(drop_out_rate) if drop_out_rate > 0.0 else nn.Identity() + ) + self.dropout2 = ( + nn.Dropout(drop_out_rate) if drop_out_rate > 0.0 else nn.Identity() + ) + + self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True) + self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True) + + def forward(self, input): + inp = input[0] + cond = input[1] + + x = inp + + x = self.norm1(x) + x = self.conv1(x) + x = self.conv2(x) + x = self.sg(x) + x = x * self.sca(x, cond) + x = self.conv3(x) + + x = self.dropout1(x) + + y = inp + x * self.beta + + # Channel Mixing + x = self.conv4(self.norm2(y)) + x = self.sg(x) + # x = self.grn(x) + x = self.conv5(x) + + x = self.dropout2(x) + + return (y + x * self.gamma, cond) + + +class NAFBlock0_learned_norm(nn.Module): + def __init__(self, c, DW_Expand=2, FFN_Expand=2, drop_out_rate=0.0, cond_chans=0): + super().__init__() + dw_channel = c * DW_Expand + self.conv1 = nn.Conv2d( + in_channels=c, + out_channels=dw_channel, + kernel_size=1, + padding=0, + stride=1, + groups=1, + bias=True, + ) + self.conv2 = nn.Conv2d( + in_channels=dw_channel, + 
out_channels=dw_channel, + kernel_size=3, + padding=1, + stride=1, + groups=dw_channel, + bias=True, + ) + self.conv3 = nn.Conv2d( + in_channels=dw_channel // 2, + out_channels=c, + kernel_size=1, + padding=0, + stride=1, + groups=1, + bias=True, + ) + + # Simplified Channel Attention + self.sca = ConditionedChannelAttention(dw_channel // 2, cond_chans) + + # SimpleGate + self.sg = SimpleGate() + + ffn_channel = FFN_Expand * c + self.conv4 = nn.Conv2d( + in_channels=c, + out_channels=ffn_channel, + kernel_size=1, + padding=0, + stride=1, + groups=1, + bias=True, + ) + self.conv5 = nn.Conv2d( + in_channels=ffn_channel // 2, + out_channels=c, + kernel_size=1, + padding=0, + stride=1, + groups=1, + bias=True, + ) + + # self.grn = GRN(ffn_channel // 2) + + self.norm1 = LayerNorm2d(c) + self.norm2 = LayerNorm2d(c) + + self.dropout1 = ( + nn.Dropout(drop_out_rate) if drop_out_rate > 0.0 else nn.Identity() + ) + self.dropout2 = ( + nn.Dropout(drop_out_rate) if drop_out_rate > 0.0 else nn.Identity() + ) + + self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True) + self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True) + self.sca_mul = ConditionedChannelAttention(c, cond_chans) + self.sca_add = ConditionedChannelAttention(c, cond_chans) + + def forward(self, input): + inp = input[0] + cond = input[1] + + x = inp + + x = self.norm1(x) + x = self.conv1(x) + x = self.conv2(x) + x = self.sg(x) + x = x * self.sca(x, cond) + x = self.conv3(x) + + x = self.dropout1(x) + + y = inp + x * self.beta + + # Channel Mixing + normed = self.norm2(y) + + # Input mediated channel attention, obstensibly to mitigate the effects of group norm on flat scenes + x = (1 + self.sca_mul(inp, cond)) * normed + self.sca_add(inp, cond) + + x = self.conv4(x) + x = self.sg(x) + # x = self.grn(x) + x = self.conv5(x) + + x = self.dropout2(x) + + + + return (y + x * self.gamma, cond) + +class NAFBlock0AdjustedNorm(nn.Module): + def __init__(self, c, DW_Expand=2, 
FFN_Expand=2, drop_out_rate=0.0, cond_chans=0): + super().__init__() + dw_channel = c * DW_Expand + self.conv1 = nn.Conv2d( + in_channels=c, + out_channels=dw_channel, + kernel_size=1, + padding=0, + stride=1, + groups=1, + bias=True, + ) + self.conv2 = nn.Conv2d( + in_channels=dw_channel, + out_channels=dw_channel, + kernel_size=3, + padding=1, + stride=1, + groups=dw_channel, + bias=True, + ) + self.conv3 = nn.Conv2d( + in_channels=dw_channel // 2, + out_channels=c, + kernel_size=1, + padding=0, + stride=1, + groups=1, + bias=True, + ) + + # Simplified Channel Attention + self.sca = ConditionedChannelAttention(dw_channel // 2, cond_chans) + + # SimpleGate + self.sg = SimpleGate() + + ffn_channel = FFN_Expand * c + self.conv4 = nn.Conv2d( + in_channels=c, + out_channels=ffn_channel, + kernel_size=1, + padding=0, + stride=1, + groups=1, + bias=True, + ) + self.conv5 = nn.Conv2d( + in_channels=ffn_channel // 2, + out_channels=c, + kernel_size=1, + padding=0, + stride=1, + groups=1, + bias=True, + ) + + # self.grn = GRN(ffn_channel // 2) + + self.norm1 = LayerNorm2dAdjusted(c) + self.norm2 = LayerNorm2dAdjusted(c) + + self.dropout1 = ( + nn.Dropout(drop_out_rate) if drop_out_rate > 0.0 else nn.Identity() + ) + self.dropout2 = ( + nn.Dropout(drop_out_rate) if drop_out_rate > 0.0 else nn.Identity() + ) + + self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True) + self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True) + self.sca_mul = ConditionedChannelAttention(c, cond_chans) + self.sca_add = ConditionedChannelAttention(c, cond_chans) + + def forward(self, input): + inp = input[0] + cond = input[1] + + x = inp + + x = self.norm1(x, mu, var) + x = self.conv1(x) + x = self.conv2(x) + x = self.sg(x) + x = x * self.sca(x, cond) + x = self.conv3(x) + + x = self.dropout1(x) + + y = inp + x * self.beta + + # Channel Mixing + normed = self.norm2(y, mu, var) + + # Input mediated channel attention, obstensibly to mitigate the effects of group 
norm on flat scenes + # x = (1 + self.sca_mul(inp, cond)) * normed + self.sca_add(inp, cond) + + x = self.conv4(normed) + x = self.sg(x) + # x = self.grn(x) + x = self.conv5(x) + + x = self.dropout2(x) + + return (y + x * self.gamma, cond, mu, var) + + +import torch.nn.functional as F + +class SwiGLU(nn.Module): + + def __init__(self, input_dim, hidden_dim, dropout=0.1): + super().__init__() + self.w1 = nn.Conv2d(input_dim, hidden_dim, 1, 1, 0, 1) + self.w2 = nn.Conv2d(input_dim, hidden_dim, 1, 1, 0, 1) + self.w3 = nn.Conv2d(hidden_dim, input_dim, 1, 1, 0, 1) + + def forward(self, x): + gate = F.silu(self.w1(x)) + value = self.w2(x) + x = gate * value + + x = self.w3(x) + return x + +class AttnBlock(nn.Module): + def __init__(self, c, FFN_Expand=2, drop_out_rate=0.0, cond_chans=0): + super().__init__() + + self.dw = nn.Conv2d( + in_channels=c, + out_channels=c, + kernel_size=3, + padding=1, + stride=1, + groups=c, + bias=True, + ) + self.nka = NKA(c) + + self.sca = ConditionedChannelAttention(c, cond_chans) + + self.norm = nn.GroupNorm(1, c) + + self.swiglu = SwiGLU(c, int(c * FFN_Expand)) + self.alpha = nn.Parameter(torch.zeros(1, c, 1, 1)) + self.beta = nn.Parameter(torch.zeros(1, c, 1, 1)) + + + def forward(self, input): + inp = input[0] + cond = input[1] + + x = self.dw(inp) + x = self.nka(x) + x = self.sca(x, cond) * x + y = self.norm(inp + self.alpha * x ) + + + x = self.swiglu(y) + x = y + self.beta * x + return (x, cond) + + +class CondSEBlock(nn.Module): + def __init__(self, chan, reduction=16, cond_chan=1): + super().__init__() + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Sequential( + nn.Linear(chan + cond_chan, chan // reduction, bias=False), + nn.ReLU(inplace=True), + nn.Linear(chan // reduction, chan, bias=False), + nn.Sigmoid() + ) + + def forward(self, x, conditioning): + b, c, _, _ = x.size() + y = self.avg_pool(x).view(b, c) + y = torch.cat([y, conditioning], dim=1) + y = self.fc(y).view(b, c, 1, 1) + return x * y.expand_as(x) + + + 
+class ConditioningCNN(nn.Module): + def __init__(self, in_channels=4, num_logits=128): + """ + Args: + in_channels (int): Number of input channels (e.g., 3 for RGB). + num_logits (int): The desired size of the output 1D logit vector. + """ + super().__init__() + + self.encoder = nn.Sequential( + nn.Conv2d(in_channels, 32, kernel_size=3, stride=2, padding='same'), + nn.ReLU(inplace=True), + nn.Conv2d(32, 64, kernel_size=3, stride=2, padding='same'), + nn.ReLU(inplace=True), + nn.Conv2d(64, 128, kernel_size=3, stride=2, padding='same'), + nn.ReLU(inplace=True), + nn.Conv2d(128, 256, kernel_size=3, stride=2, padding='same'), + nn.ReLU(inplace=True) + ) + + self.logit_head = nn.Sequential( + nn.AdaptiveAvgPool2d((1, 1)), + nn.Flatten(), + nn.Linear(256, num_logits) + ) + def forward(self, x): + x = self.encoder(x) + x = self.logit_head(x) + return x + +class Restorer(nn.Module): + def __init__( + self, + in_channels=3, + out_channels=3, + middle_blk_num=1, + enc_blk_nums=[], + dec_blk_nums=[], + chans = [], + cond_input=1, + cond_output=32, + expand_dims=2, + drop_out_rate=0.0, + drop_out_rate_increment=0.0, + rggb = False, + use_CondFuserV2 = False, + use_add = False, + use_CondFuserV3 = False, + use_CondFuserV4 = False, + use_attnblock = False, + use_NAFBlock0_learned_norm=False, + use_cond_net = False, + cond_net_num = 32, + use_input_stats=False, + use_NAFBlock0AdjustedNorm=False, + ): + super().__init__() + if use_attnblock: + block = AttnBlock + elif use_NAFBlock0_learned_norm: + block = NAFBlock0_learned_norm + elif use_NAFBlock0AdjustedNorm: + block = NAFBlock0AdjustedNorm + else: + block = NAFBlock0 + + width = chans[0] + + self.expand_dims = expand_dims + self.conditioning_gen = nn.Sequential( + nn.Linear(cond_input, 64), nn.ReLU(), nn.Dropout(drop_out_rate), nn.Linear(64, cond_output), + ) + + + self.use_cond_net = use_cond_net + if use_cond_net: + self.cond_net = ConditioningCNN(in_channels=in_channels, num_logits=cond_net_num) + cond_output = cond_output 
+ cond_net_num + + self.use_input_stats = use_input_stats + if use_input_stats: + cond_output = cond_output + in_channels * 2 + + self.rggb = rggb + if not rggb: + self.intro = nn.Conv2d( + in_channels=in_channels, + out_channels=width, + kernel_size=3, + padding=1, + stride=1, + groups=1, + bias=True, + ) + else: + self.intro = nn.Sequential( + + nn.Conv2d( + in_channels=in_channels, + out_channels=width * 2 ** 2, + kernel_size=3, + padding=1, + stride=1, + groups=1, + bias=True, + ), + nn.PixelShuffle(2) + ) + + nn.Conv2d( + in_channels=in_channels, + out_channels=width, + kernel_size=3, + padding=1, + stride=1, + groups=1, + bias=True, + ) + self.ending = nn.Conv2d( + in_channels=width, + out_channels=out_channels, + kernel_size=3, + padding=1, + stride=1, + groups=1, + bias=True, + ) + + self.encoders = nn.ModuleList() + self.decoders = nn.ModuleList() + self.middle_blks = nn.ModuleList() + self.ups = nn.ModuleList() + self.downs = nn.ModuleList() + self.merges = nn.ModuleList() + + + # for num in enc_blk_nums: + for i in range(len(enc_blk_nums)): + current_chan = chans[i] + next_chan = chans[i + 1] + num = enc_blk_nums[i] + self.encoders.append( + nn.Sequential( + *[ + block(current_chan, cond_chans=cond_output, drop_out_rate=drop_out_rate) + for _ in range(num) + ] + ) + ) + drop_out_rate += drop_out_rate_increment + self.downs.append(nn.Conv2d(current_chan, next_chan, 2, 2)) + + self.middle_blks = nn.Sequential( + *[ + block(next_chan, cond_chans=cond_output, drop_out_rate=drop_out_rate) + for _ in range(middle_blk_num) + ] + ) + + for i in range(len(dec_blk_nums)): + current_chan = chans[-i-1] + next_chan = chans[-i-2] + num = dec_blk_nums[i] + self.ups.append( + nn.Sequential( + nn.Conv2d(current_chan, next_chan * 2 ** 2, 1, bias=False), nn.PixelShuffle(2) + ) + ) + drop_out_rate -= drop_out_rate_increment + if use_CondFuserV2: + self.merges.append(CondFuserV2(next_chan, cond_chan=cond_output)) + elif use_add: + self.merges.append(CondFuserAdd(next_chan, 
cond_chan=cond_output)) + elif use_CondFuserV3: + self.merges.append(CondFuserV3(next_chan, cond_chan=cond_output)) + elif use_CondFuserV4: + self.merges.append(CondFuserV4(next_chan, cond_chan=cond_output)) + else: + self.merges.append(CondFuser(next_chan, cond_chan=cond_output)) + + self.decoders.append( + nn.Sequential( + *[ + block(next_chan, cond_chans=cond_output, drop_out_rate=drop_out_rate) + for _ in range(num) + ] + ) + ) + + + self.padder_size = 2 ** len(self.encoders) + + def forward(self, inp, cond_in): + # Conditioning: + cond = self.conditioning_gen(cond_in) + + # if self.use_cond_net: + # extra_cond = self.cond_net(inp) + # cond = torch.cat([cond, extra_cond], dim=1) + # if self.use_input_stats: + # mu = inp.mean((2,3), keepdim=True) + # var = (inp - mu).pow(2).mean((2,3), keepdim=False) + # mu = mu.squeeze(-1).squeeze(-1) + # cond = torch.cat([cond, mu, var], dim=1) + + B, C, H, W = inp.shape + if self.rggb: + H = 2 * H + W = 2 * W + inp = self.check_image_size(inp) + + x = self.intro(inp) + + encs = [] + for encoder, down in zip(self.encoders, self.downs): + x = encoder((x, cond))[0] + encs.append(x) + x = down(x) + + x = self.middle_blks((x, cond))[0] + + for decoder, up, merge, enc_skip in zip(self.decoders, self.ups, self.merges, encs[::-1]): + x = up(x) + x = merge(x, enc_skip, cond) + x = decoder((x, cond))[0] + + x = self.ending(x) + return x[:, :, :H, :W] + + def check_image_size(self, x): + _, _, h, w = x.size() + mod_pad_h = (self.padder_size - h % self.padder_size) % self.padder_size + mod_pad_w = (self.padder_size - w % self.padder_size) % self.padder_size + x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h)) + return x + +class ModelWrapper(nn.Module): + def __init__(self, **kwargs): + self.gamma = 1 + if 'gamma' in kwargs: + self.gamma = kwargs.pop('gamma') + super().__init__() + self.model = Restorer( + **kwargs + ) + + def forward(self, x, cond, residual): + x = x.clip(0, 1) ** (1. / self.gamma) + residual = residual.clip(0, 1) ** (1. 
/ self.gamma) + output = self.model(x, cond) + output = (residual + output).clip(0, 1) ** (self.gamma) + return output + + +def make_full_model_RGGB(params, model_name=None): + model = ModelWrapper(**params) + if not model_name is None: + state_dict = torch.load(model_name, map_location="cpu") + model.load_state_dict(state_dict) + return model + + + +from src.Restorer.Cond_NAF_demosaic import DemosaicingFromRGGB + +class ModelWrapperNoRes(nn.Module): + def __init__(self, **kwargs): + super().__init__() + + self.gamma = 1 + if 'gamma' in kwargs: + self.gamma = kwargs.pop('gamma') + + self.demosaicer = DemosaicingFromRGGB() + self.model = Restorer( + **kwargs + ) + + def forward(self, rggb, cond): + rggb = rggb.clip(0, 1) ** (1. / self.gamma) + debayered = self.demosaicer(rggb, cond) + debayered = debayered.clip(0, 1) ** (1. / self.gamma) + output = self.model(rggb, cond) + output = (debayered + output).clip(0, 1) ** (self.gamma) + return output + + +def make_full_model_RGGB_NoRes(params, model_name=None): + model = ModelWrapperNoRes(**params) + if not model_name is None: + state_dict = torch.load(model_name, map_location="cpu") + model.load_state_dict(state_dict) + return model + diff --git a/src/Restorer/Cond_NAF_demosaic.py b/src/Restorer/Cond_NAF_demosaic.py new file mode 100644 index 0000000..fd872d4 --- /dev/null +++ b/src/Restorer/Cond_NAF_demosaic.py @@ -0,0 +1,938 @@ +import torch.nn.functional as F +import torch +import torch.nn as nn +from src.Restorer.debayering import Debayer5x5 + + +class LayerNorm2dAdjusted(nn.Module): + def __init__(self, channels, eps=1e-6): + super(LayerNorm2d, self).__init__() + self.register_parameter("weight", nn.Parameter(torch.ones(channels))) + self.register_parameter("bias", nn.Parameter(torch.zeros(channels))) + self.eps = eps + + def forward(self, x, target_mu, target_var): + mu = x.mean(1, keepdim=True) + var = (x - mu).pow(2).mean(1, keepdim=True) + + y = (x - mu) / torch.sqrt(var + self.eps) + + y = y * 
torch.sqrt(target_var + self.eps) + target_mu + + weight_view = self.weight.view(1, self.weight.size(0), 1, 1) + bias_view = self.bias.view(1, self.bias.size(0), 1, 1) + + y = weight_view * y + bias_view + return y + +class LayerNorm2d(nn.Module): + def __init__(self, channels, eps=1e-6): + super(LayerNorm2d, self).__init__() + self.register_parameter("weight", nn.Parameter(torch.ones(channels))) + self.register_parameter("bias", nn.Parameter(torch.zeros(channels))) + self.eps = eps + + def forward(self, x): + mu = x.mean(1, keepdim=True) + var = (x - mu).pow(2).mean(1, keepdim=True) + + y = (x - mu) / torch.sqrt(var + self.eps) + + weight_view = self.weight.view(1, self.weight.size(0), 1, 1) + bias_view = self.bias.view(1, self.bias.size(0), 1, 1) + + y = weight_view * y + bias_view + return y + +class SimpleGate(nn.Module): + def forward(self, x): + x1, x2 = x.chunk(2, dim=1) + return x1 * x2 + + +class ConditionedChannelAttention(nn.Module): + def __init__(self, dims, cat_dims): + super().__init__() + in_dim = dims + cat_dims + self.mlp = nn.Sequential(nn.Linear(in_dim, dims)) + self.pool = nn.AdaptiveAvgPool2d(1) + + def forward(self, x, conditioning): + pool = self.pool(x) + conditioning = conditioning.unsqueeze(-1).unsqueeze(-1) + cat_channels = torch.cat([pool, conditioning], dim=1) + cat_channels = cat_channels.permute(0, 2, 3, 1) + ca = self.mlp(cat_channels).permute(0, 3, 1, 2) + + return ca + +class CondFuser(nn.Module): + def __init__(self, chan, cond_chan=1): + super().__init__() + self.cca = ConditionedChannelAttention(chan * 2, cond_chan) + # self.spa = nn.Conv2d( + # in_channels=chan * 2, + # out_channels=1, + # kernel_size=3, + # padding=1, + # stride=1, + # groups=1, + # bias=True, + # ) + + def forward(self, x1, x2, cond): + x = torch.cat([x1, x2], dim=1) + x = self.cca(x, cond) * x + # spa = torch.sigmoid(self.spa(x)) + + x1, x2 = x.chunk(2, dim=1) + # return x1 * spa + x2 * (1 - spa) + return x1 + x2 + + +class NKA(nn.Module): + def 
__init__(self, dim, channel_reduction = 8): + super().__init__() + + reduced_channels = dim // channel_reduction + self.proj_1 = nn.Conv2d(dim, reduced_channels, 1, 1, 0) + self.dwconv = nn.Conv2d(reduced_channels, reduced_channels, 3, 1, 1, groups=reduced_channels) + self.proj_2 = nn.Conv2d(reduced_channels, reduced_channels * 2, 1, 1, 0) + self.sg = SimpleGate() + self.attention = nn.Conv2d(reduced_channels, dim, 1, 1, 0) + + def forward(self, x): + B, C, H, W = x.shape + # First projection to a smaller dimension + y = self.proj_1(x) + # DW conv + attn = self.dwconv(y) + # PW to increase channel count for SG + attn = self.proj_2(attn) + # Non-linearity + attn = self.sg(attn) + # Back to original dimensions + out = x * self.attention(attn) + return out + + +class CondFuser(nn.Module): + def __init__(self, chan, cond_chan=1): + super().__init__() + self.cca = ConditionedChannelAttention(chan * 2, cond_chan) + + def forward(self, x1, x2, cond): + x = torch.cat([x1, x2], dim=1) + x = self.cca(x, cond) * x + x1, x2 = x.chunk(2, dim=1) + return x1 + x2 + +class CondFuserAdd(nn.Module): + def __init__(self, chan, cond_chan=1): + super().__init__() + + def forward(self, x1, x2, cond): + return x1 + x2 + +class CondFuserV2(nn.Module): + def __init__(self, chan, cond_chan=1): + super().__init__() + self.cca = ConditionedChannelAttention(chan * 2, cond_chan) + self.spa = NKA(chan * 2) + + def forward(self, x1, x2, cond): + x = torch.cat([x1, x2], dim=1) + x = self.cca(x, cond) * x + spa = torch.sigmoid(self.spa(x)) * x + x1, x2 = x.chunk(2, dim=1) + return x1 * x2 + + +class CondFuserV3(nn.Module): + def __init__(self, chan, cond_chan=1): + super().__init__() + self.cca = ConditionedChannelAttention(chan * 2, cond_chan) + self.spa = nn.Conv2d( + in_channels=chan * 2, + out_channels=1, + kernel_size=3, + padding=1, + stride=1, + groups=1, + bias=True, + ) + + def forward(self, x1, x2, cond): + x = torch.cat([x1, x2], dim=1) + x = self.cca(x, cond) * x + spa = 
torch.sigmoid(self.spa(x)) + + x1, x2 = x.chunk(2, dim=1) + return x1 * spa + x2 * (1 - spa) + +class CondFuserV4(nn.Module): + def __init__(self, chan, cond_chan=1): + super().__init__() + self.cca = ConditionedChannelAttention(chan * 2, cond_chan) + self.pw = nn.Conv2d(chan * 2, chan, 1, stride=1, padding=0, groups=1) + def forward(self, x1, x2, cond): + x = torch.cat([x1, x2], dim=1) + x = self.cca(x, cond) * x + x = self.pw(x) + return x + + +class NAFBlock0(nn.Module): + def __init__(self, c, DW_Expand=2, FFN_Expand=2, drop_out_rate=0.0, cond_chans=0): + super().__init__() + dw_channel = c * DW_Expand + self.conv1 = nn.Conv2d( + in_channels=c, + out_channels=dw_channel, + kernel_size=1, + padding=0, + stride=1, + groups=1, + bias=True, + ) + self.conv2 = nn.Conv2d( + in_channels=dw_channel, + out_channels=dw_channel, + kernel_size=3, + padding=1, + stride=1, + groups=dw_channel, + bias=True, + ) + self.conv3 = nn.Conv2d( + in_channels=dw_channel // 2, + out_channels=c, + kernel_size=1, + padding=0, + stride=1, + groups=1, + bias=True, + ) + + # Simplified Channel Attention + self.sca = ConditionedChannelAttention(dw_channel // 2, cond_chans) + + # SimpleGate + self.sg = SimpleGate() + + ffn_channel = FFN_Expand * c + self.conv4 = nn.Conv2d( + in_channels=c, + out_channels=ffn_channel, + kernel_size=1, + padding=0, + stride=1, + groups=1, + bias=True, + ) + self.conv5 = nn.Conv2d( + in_channels=ffn_channel // 2, + out_channels=c, + kernel_size=1, + padding=0, + stride=1, + groups=1, + bias=True, + ) + + # self.grn = GRN(ffn_channel // 2) + + self.norm1 = LayerNorm2d(c) + self.norm2 = LayerNorm2d(c) + + self.dropout1 = ( + nn.Dropout(drop_out_rate) if drop_out_rate > 0.0 else nn.Identity() + ) + self.dropout2 = ( + nn.Dropout(drop_out_rate) if drop_out_rate > 0.0 else nn.Identity() + ) + + self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True) + self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True) + + def forward(self, 
input): + inp = input[0] + cond = input[1] + + x = inp + + x = self.norm1(x) + x = self.conv1(x) + x = self.conv2(x) + x = self.sg(x) + x = x * self.sca(x, cond) + x = self.conv3(x) + + x = self.dropout1(x) + + y = inp + x * self.beta + + # Channel Mixing + x = self.conv4(self.norm2(y)) + x = self.sg(x) + # x = self.grn(x) + x = self.conv5(x) + + x = self.dropout2(x) + + return (y + x * self.gamma, cond) + + +class NAFBlock0_learned_norm(nn.Module): + def __init__(self, c, DW_Expand=2, FFN_Expand=2, drop_out_rate=0.0, cond_chans=0): + super().__init__() + dw_channel = c * DW_Expand + self.conv1 = nn.Conv2d( + in_channels=c, + out_channels=dw_channel, + kernel_size=1, + padding=0, + stride=1, + groups=1, + bias=True, + ) + self.conv2 = nn.Conv2d( + in_channels=dw_channel, + out_channels=dw_channel, + kernel_size=3, + padding=1, + stride=1, + groups=dw_channel, + bias=True, + ) + self.conv3 = nn.Conv2d( + in_channels=dw_channel // 2, + out_channels=c, + kernel_size=1, + padding=0, + stride=1, + groups=1, + bias=True, + ) + + # Simplified Channel Attention + self.sca = ConditionedChannelAttention(dw_channel // 2, cond_chans) + + # SimpleGate + self.sg = SimpleGate() + + ffn_channel = FFN_Expand * c + self.conv4 = nn.Conv2d( + in_channels=c, + out_channels=ffn_channel, + kernel_size=1, + padding=0, + stride=1, + groups=1, + bias=True, + ) + self.conv5 = nn.Conv2d( + in_channels=ffn_channel // 2, + out_channels=c, + kernel_size=1, + padding=0, + stride=1, + groups=1, + bias=True, + ) + + # self.grn = GRN(ffn_channel // 2) + + self.norm1 = LayerNorm2d(c) + self.norm2 = LayerNorm2d(c) + + self.dropout1 = ( + nn.Dropout(drop_out_rate) if drop_out_rate > 0.0 else nn.Identity() + ) + self.dropout2 = ( + nn.Dropout(drop_out_rate) if drop_out_rate > 0.0 else nn.Identity() + ) + + self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True) + self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True) + self.sca_mul = ConditionedChannelAttention(c, 
cond_chans) + self.sca_add = ConditionedChannelAttention(c, cond_chans) + + def forward(self, input): + inp = input[0] + cond = input[1] + + x = inp + + x = self.norm1(x) + x = self.conv1(x) + x = self.conv2(x) + x = self.sg(x) + x = x * self.sca(x, cond) + x = self.conv3(x) + + x = self.dropout1(x) + + y = inp + x * self.beta + + # Channel Mixing + normed = self.norm2(y) + + # Input mediated channel attention, obstensibly to mitigate the effects of group norm on flat scenes + x = (1 + self.sca_mul(inp, cond)) * normed + self.sca_add(inp, cond) + + x = self.conv4(x) + x = self.sg(x) + # x = self.grn(x) + x = self.conv5(x) + + x = self.dropout2(x) + + + + return (y + x * self.gamma, cond) + +class NAFBlock0AdjustedNorm(nn.Module): + def __init__(self, c, DW_Expand=2, FFN_Expand=2, drop_out_rate=0.0, cond_chans=0): + super().__init__() + dw_channel = c * DW_Expand + self.conv1 = nn.Conv2d( + in_channels=c, + out_channels=dw_channel, + kernel_size=1, + padding=0, + stride=1, + groups=1, + bias=True, + ) + self.conv2 = nn.Conv2d( + in_channels=dw_channel, + out_channels=dw_channel, + kernel_size=3, + padding=1, + stride=1, + groups=dw_channel, + bias=True, + ) + self.conv3 = nn.Conv2d( + in_channels=dw_channel // 2, + out_channels=c, + kernel_size=1, + padding=0, + stride=1, + groups=1, + bias=True, + ) + + # Simplified Channel Attention + self.sca = ConditionedChannelAttention(dw_channel // 2, cond_chans) + + # SimpleGate + self.sg = SimpleGate() + + ffn_channel = FFN_Expand * c + self.conv4 = nn.Conv2d( + in_channels=c, + out_channels=ffn_channel, + kernel_size=1, + padding=0, + stride=1, + groups=1, + bias=True, + ) + self.conv5 = nn.Conv2d( + in_channels=ffn_channel // 2, + out_channels=c, + kernel_size=1, + padding=0, + stride=1, + groups=1, + bias=True, + ) + + # self.grn = GRN(ffn_channel // 2) + + self.norm1 = LayerNorm2dAdjusted(c) + self.norm2 = LayerNorm2dAdjusted(c) + + self.dropout1 = ( + nn.Dropout(drop_out_rate) if drop_out_rate > 0.0 else nn.Identity() 
+ ) + self.dropout2 = ( + nn.Dropout(drop_out_rate) if drop_out_rate > 0.0 else nn.Identity() + ) + + self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True) + self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True) + self.sca_mul = ConditionedChannelAttention(c, cond_chans) + self.sca_add = ConditionedChannelAttention(c, cond_chans) + + def forward(self, input): + inp = input[0] + cond = input[1] + + x = inp + + x = self.norm1(x, mu, var) + x = self.conv1(x) + x = self.conv2(x) + x = self.sg(x) + x = x * self.sca(x, cond) + x = self.conv3(x) + + x = self.dropout1(x) + + y = inp + x * self.beta + + # Channel Mixing + normed = self.norm2(y, mu, var) + + # Input mediated channel attention, obstensibly to mitigate the effects of group norm on flat scenes + # x = (1 + self.sca_mul(inp, cond)) * normed + self.sca_add(inp, cond) + + x = self.conv4(normed) + x = self.sg(x) + # x = self.grn(x) + x = self.conv5(x) + + x = self.dropout2(x) + + return (y + x * self.gamma, cond, mu, var) + + +import torch.nn.functional as F + +class SwiGLU(nn.Module): + + def __init__(self, input_dim, hidden_dim, dropout=0.1): + super().__init__() + self.w1 = nn.Conv2d(input_dim, hidden_dim, 1, 1, 0, 1) + self.w2 = nn.Conv2d(input_dim, hidden_dim, 1, 1, 0, 1) + self.w3 = nn.Conv2d(hidden_dim, input_dim, 1, 1, 0, 1) + + def forward(self, x): + gate = F.silu(self.w1(x)) + value = self.w2(x) + x = gate * value + + x = self.w3(x) + return x + +class AttnBlock(nn.Module): + def __init__(self, c, FFN_Expand=2, drop_out_rate=0.0, cond_chans=0): + super().__init__() + + self.dw = nn.Conv2d( + in_channels=c, + out_channels=c, + kernel_size=3, + padding=1, + stride=1, + groups=c, + bias=True, + ) + self.nka = NKA(c) + + self.sca = ConditionedChannelAttention(c, cond_chans) + + self.norm = nn.GroupNorm(1, c) + + self.swiglu = SwiGLU(c, int(c * FFN_Expand)) + self.alpha = nn.Parameter(torch.zeros(1, c, 1, 1)) + self.beta = nn.Parameter(torch.zeros(1, c, 1, 1)) + + + def 
forward(self, input): + inp = input[0] + cond = input[1] + + x = self.dw(inp) + x = self.nka(x) + x = self.sca(x, cond) * x + y = self.norm(inp + self.alpha * x ) + + + x = self.swiglu(y) + x = y + self.beta * x + return (x, cond) + + +class CondSEBlock(nn.Module): + def __init__(self, chan, reduction=16, cond_chan=1): + super().__init__() + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Sequential( + nn.Linear(chan + cond_chan, chan // reduction, bias=False), + nn.ReLU(inplace=True), + nn.Linear(chan // reduction, chan, bias=False), + nn.Sigmoid() + ) + + def forward(self, x, conditioning): + b, c, _, _ = x.size() + y = self.avg_pool(x).view(b, c) + y = torch.cat([y, conditioning], dim=1) + y = self.fc(y).view(b, c, 1, 1) + return x * y.expand_as(x) + + + +class ConditioningCNN(nn.Module): + def __init__(self, in_channels=4, num_logits=128): + """ + Args: + in_channels (int): Number of input channels (e.g., 3 for RGB). + num_logits (int): The desired size of the output 1D logit vector. 
+ """ + super().__init__() + + self.encoder = nn.Sequential( + nn.Conv2d(in_channels, 32, kernel_size=3, stride=2, padding='same'), + nn.ReLU(inplace=True), + nn.Conv2d(32, 64, kernel_size=3, stride=2, padding='same'), + nn.ReLU(inplace=True), + nn.Conv2d(64, 128, kernel_size=3, stride=2, padding='same'), + nn.ReLU(inplace=True), + nn.Conv2d(128, 256, kernel_size=3, stride=2, padding='same'), + nn.ReLU(inplace=True) + ) + + self.logit_head = nn.Sequential( + nn.AdaptiveAvgPool2d((1, 1)), + nn.Flatten(), + nn.Linear(256, num_logits) + ) + def forward(self, x): + x = self.encoder(x) + x = self.logit_head(x) + return x + +class Restorer(nn.Module): + def __init__( + self, + in_channels=3, + out_channels=3, + middle_blk_num=1, + enc_blk_nums=[], + dec_blk_nums=[], + chans = [], + cond_input=1, + cond_output=32, + expand_dims=2, + drop_out_rate=0.0, + drop_out_rate_increment=0.0, + rggb = False, + use_CondFuserV2 = False, + use_add = False, + use_CondFuserV3 = False, + use_CondFuserV4 = False, + use_attnblock = False, + use_NAFBlock0_learned_norm=False, + use_cond_net = False, + cond_net_num = 32, + use_input_stats=False, + use_NAFBlock0AdjustedNorm=False, + ): + super().__init__() + if use_attnblock: + block = AttnBlock + elif use_NAFBlock0_learned_norm: + block = NAFBlock0_learned_norm + elif use_NAFBlock0AdjustedNorm: + block = NAFBlock0AdjustedNorm + else: + block = NAFBlock0 + + width = chans[0] + + self.expand_dims = expand_dims + self.conditioning_gen = nn.Sequential( + nn.Linear(cond_input, 64), nn.ReLU(), nn.Dropout(drop_out_rate), nn.Linear(64, cond_output), + ) + + + self.use_cond_net = use_cond_net + if use_cond_net: + self.cond_net = ConditioningCNN(in_channels=in_channels, num_logits=cond_net_num) + cond_output = cond_output + cond_net_num + + self.use_input_stats = use_input_stats + if use_input_stats: + cond_output = cond_output + in_channels * 2 + + self.rggb = rggb + if not rggb: + self.intro = nn.Conv2d( + in_channels=in_channels, + 
out_channels=width, + kernel_size=3, + padding=1, + stride=1, + groups=1, + bias=True, + ) + else: + self.intro = nn.Sequential( + + nn.Conv2d( + in_channels=in_channels, + out_channels=width * 2 ** 2, + kernel_size=3, + padding=1, + stride=1, + groups=1, + bias=True, + ), + nn.PixelShuffle(2) + ) + + nn.Conv2d( + in_channels=in_channels, + out_channels=width, + kernel_size=3, + padding=1, + stride=1, + groups=1, + bias=True, + ) + self.ending = nn.Conv2d( + in_channels=width, + out_channels=out_channels, + kernel_size=3, + padding=1, + stride=1, + groups=1, + bias=True, + ) + + self.encoders = nn.ModuleList() + self.decoders = nn.ModuleList() + self.middle_blks = nn.ModuleList() + self.ups = nn.ModuleList() + self.downs = nn.ModuleList() + self.merges = nn.ModuleList() + + + # for num in enc_blk_nums: + for i in range(len(enc_blk_nums)): + current_chan = chans[i] + next_chan = chans[i + 1] + num = enc_blk_nums[i] + self.encoders.append( + nn.Sequential( + *[ + block(current_chan, cond_chans=cond_output, drop_out_rate=drop_out_rate) + for _ in range(num) + ] + ) + ) + drop_out_rate += drop_out_rate_increment + self.downs.append(nn.Conv2d(current_chan, next_chan, 2, 2)) + + self.middle_blks = nn.Sequential( + *[ + block(next_chan, cond_chans=cond_output, drop_out_rate=drop_out_rate) + for _ in range(middle_blk_num) + ] + ) + + for i in range(len(dec_blk_nums)): + current_chan = chans[-i-1] + next_chan = chans[-i-2] + num = dec_blk_nums[i] + self.ups.append( + nn.Sequential( + nn.Conv2d(current_chan, next_chan * 2 ** 2, 1, bias=False), nn.PixelShuffle(2) + ) + ) + drop_out_rate -= drop_out_rate_increment + if use_CondFuserV2: + self.merges.append(CondFuserV2(next_chan, cond_chan=cond_output)) + elif use_add: + self.merges.append(CondFuserAdd(next_chan, cond_chan=cond_output)) + elif use_CondFuserV3: + self.merges.append(CondFuserV3(next_chan, cond_chan=cond_output)) + elif use_CondFuserV4: + self.merges.append(CondFuserV4(next_chan, cond_chan=cond_output)) + else: + 
self.merges.append(CondFuser(next_chan, cond_chan=cond_output)) + + self.decoders.append( + nn.Sequential( + *[ + block(next_chan, cond_chans=cond_output, drop_out_rate=drop_out_rate) + for _ in range(num) + ] + ) + ) + + + self.padder_size = 2 ** len(self.encoders) + + self.alpha = nn.Parameter(torch.zeros(1, 1, 1, 1)) + + def forward(self, inp, cond_in): + # Conditioning: + cond = self.conditioning_gen(cond_in) + + # if self.use_cond_net: + # extra_cond = self.cond_net(inp) + # cond = torch.cat([cond, extra_cond], dim=1) + # if self.use_input_stats: + # mu = inp.mean((2,3), keepdim=True) + # var = (inp - mu).pow(2).mean((2,3), keepdim=False) + # mu = mu.squeeze(-1).squeeze(-1) + # cond = torch.cat([cond, mu, var], dim=1) + + B, C, H, W = inp.shape + if self.rggb: + H = 2 * H + W = 2 * W + inp = self.check_image_size(inp) + + x = self.intro(inp) + + encs = [] + for encoder, down in zip(self.encoders, self.downs): + x = encoder((x, cond))[0] + encs.append(x) + x = down(x) + + x = self.middle_blks((x, cond))[0] + + for decoder, up, merge, enc_skip in zip(self.decoders, self.ups, self.merges, encs[::-1]): + x = up(x) + x = merge(x, enc_skip, cond) + x = decoder((x, cond))[0] + + x = self.ending(x) * self.alpha + return x[:, :, :H, :W] + + def check_image_size(self, x): + _, _, h, w = x.size() + mod_pad_h = (self.padder_size - h % self.padder_size) % self.padder_size + mod_pad_w = (self.padder_size - w % self.padder_size) % self.padder_size + x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h)) + return x + +class ModelWrapper(nn.Module): + def __init__(self, **kwargs): + self.gamma = 1 + if 'gamma' in kwargs: + self.gamma = kwargs.pop('gamma') + super().__init__() + self.model = Restorer( + **kwargs + ) + + def forward(self, x, cond, residual): + x = x.clip(0, 1) ** (1. / self.gamma) + residual = residual.clip(0, 1) ** (1. 
/ self.gamma) + output = self.model(x, cond) + output = (residual + output).clip(0, 1) ** (self.gamma) + return output + + +def make_full_model_RGGB(params, model_name=None): + model = ModelWrapper(**params) + if not model_name is None: + state_dict = torch.load(model_name, map_location="cpu") + model.load_state_dict(state_dict) + return model + + +class DemosaicingModelWrapper(nn.Module): + def __init__(self, **kwargs): + self.gamma = 1 + if 'gamma' in kwargs: + self.gamma = kwargs.pop('gamma') + super().__init__() + self.demosaicing = Debayer5x5() + self.model = Restorer( + **kwargs + ) + + + def forward(self, x, cond): + x = x.clip(0, 1) ** (1. / self.gamma) + bayer = nn.functional.pixel_shuffle(x, 2) + debayered = self.demosaicing(bayer) + output = (self.model(x, cond) + debayered).clip(0,1) ** (self.gamma) + return output + + +def make_full_model_RGGB_Demosaicing(params, model_name=None): + + model = DemosaicingModelWrapper(**params) + if not model_name is None: + state_dict = torch.load(model_name, map_location="cpu") + model.load_state_dict(state_dict) + return model + + +class DemosaicingModelLW(nn.Module): + def __init__(self, **kwargs): + super().__init__() + self.demosaicing = Debayer5x5() + self.ps = nn.PixelShuffle(2) + self.us = nn.PixelUnshuffle(2) + self.block = NAFBlock0(16, cond_chans=1) + self.pw_out = nn.Conv2d(4, 3, 1) + + + def forward(self, rggb, cond): + + bayer = self.ps(rggb, 2) + debayered = self.demosaicing(bayer) + shuffled_bayer = self.us(debayered) + x = torch.cat([rggb, shuffled_bayer], dim=1) + x = self.block((x, cond)) + x = self.ps(x) + x = self.pw_out(x) + output = x + debayered + return output + + + +class DemosaicingModelWrapperSequential(nn.Module): + def __init__(self, **kwargs): + self.gamma = 1 + if 'gamma' in kwargs: + self.gamma = kwargs.pop('gamma') + kwargs['rggb'] = False + kwargs['in_channels'] = 3 + super().__init__() + self.demosaicing = Debayer5x5() + self.model = Restorer( + **kwargs + ) + + + def forward(self, x, 
cond): + x = x.clip(0, 1) ** (1. / self.gamma) + bayer = nn.functional.pixel_shuffle(x, 2) + debayered = self.demosaicing(bayer) + output = (self.model(debayered, cond) + debayered).clip(0,1) ** (self.gamma) + return output + + +def make_full_model_RGGB_Demosaicing_Sequential(params, model_name=None): + + model = DemosaicingModelWrapperSequential(**params) + if not model_name is None: + state_dict = torch.load(model_name, map_location="cpu") + model.load_state_dict(state_dict) + return model + + + +class DemosaicingFromRGGB(nn.Module): + def __init__(self): + super().__init__() + self.demosaicing = Debayer5x5() + + def forward(self, x, cond): + bayer = nn.functional.pixel_shuffle(x, 2) + debayered = self.demosaicing(bayer) + return debayered + \ No newline at end of file diff --git a/src/Restorer/Cond_NAF_original.py b/src/Restorer/Cond_NAF_original.py new file mode 100644 index 0000000..87185d4 --- /dev/null +++ b/src/Restorer/Cond_NAF_original.py @@ -0,0 +1,393 @@ +import torch.nn.functional as F +import torch +import torch.nn as nn + +class LayerNormFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, x, weight, bias, eps): + ctx.eps = eps + N, C, H, W = x.size() + mu = x.mean(1, keepdim=True) + var = (x - mu).pow(2).mean(1, keepdim=True) + y = (x - mu) / (var + eps).sqrt() + ctx.save_for_backward(y, var, weight) + y = weight.view(1, C, 1, 1) * y + bias.view(1, C, 1, 1) + return y + + @staticmethod + def backward(ctx, grad_output): + eps = ctx.eps + + N, C, H, W = grad_output.size() + y, var, weight = ctx.saved_variables + g = grad_output * weight.view(1, C, 1, 1) + mean_g = g.mean(dim=1, keepdim=True) + + mean_gy = (g * y).mean(dim=1, keepdim=True) + gx = 1.0 / torch.sqrt(var + eps) * (g - y * mean_gy - mean_g) + return ( + gx, + (grad_output * y).sum(dim=3).sum(dim=2).sum(dim=0), + grad_output.sum(dim=3).sum(dim=2).sum(dim=0), + None, + ) + + +class LayerNorm2d(nn.Module): + def __init__(self, channels, eps=1e-6): + super(LayerNorm2d, 
self).__init__() + self.register_parameter("weight", nn.Parameter(torch.ones(channels))) + self.register_parameter("bias", nn.Parameter(torch.zeros(channels))) + self.eps = eps + + def forward(self, x): + return LayerNormFunction.apply(x, self.weight, self.bias, self.eps) + + +class SimpleGate(nn.Module): + def forward(self, x): + x1, x2 = x.chunk(2, dim=1) + return x1 * x2 + + +class ConditionedChannelAttention(nn.Module): + def __init__(self, dims, cat_dims): + super().__init__() + in_dim = dims + cat_dims + # self.mlp = nn.Sequential( + # nn.Linear(in_dim, int(in_dim*1.5)), + # nn.GELU(), + # nn.Dropout(0.2), + # nn.Linear(int(in_dim*1.5), dims) + # ) + self.mlp = nn.Sequential(nn.Linear(in_dim, dims)) + self.pool = nn.AdaptiveAvgPool2d(1) + + def forward(self, x, conditioning): + pool = self.pool(x) + conditioning = conditioning.unsqueeze(-1).unsqueeze(-1) + cat_channels = torch.cat([pool, conditioning], dim=1) + cat_channels = cat_channels.permute(0, 2, 3, 1) + ca = self.mlp(cat_channels).permute(0, 3, 1, 2) + + return ca + + +class NAFBlock0(nn.Module): + def __init__(self, c, DW_Expand=2, FFN_Expand=2, drop_out_rate=0.0, cond_chans=0): + super().__init__() + dw_channel = c * DW_Expand + self.conv1 = nn.Conv2d( + in_channels=c, + out_channels=dw_channel, + kernel_size=1, + padding=0, + stride=1, + groups=1, + bias=True, + ) + self.conv2 = nn.Conv2d( + in_channels=dw_channel, + out_channels=dw_channel, + kernel_size=3, + padding=1, + stride=1, + groups=dw_channel, + bias=True, + ) + self.conv3 = nn.Conv2d( + in_channels=dw_channel // 2, + out_channels=c, + kernel_size=1, + padding=0, + stride=1, + groups=1, + bias=True, + ) + + # Simplified Channel Attention + self.sca = ConditionedChannelAttention(dw_channel // 2, cond_chans) + + # SimpleGate + self.sg = SimpleGate() + + ffn_channel = FFN_Expand * c + self.conv4 = nn.Conv2d( + in_channels=c, + out_channels=ffn_channel, + kernel_size=1, + padding=0, + stride=1, + groups=1, + bias=True, + ) + self.conv5 = 
nn.Conv2d( + in_channels=ffn_channel // 2, + out_channels=c, + kernel_size=1, + padding=0, + stride=1, + groups=1, + bias=True, + ) + + # self.grn = GRN(ffn_channel // 2) + + self.norm1 = LayerNorm2d(c) + self.norm2 = LayerNorm2d(c) + + self.dropout1 = ( + nn.Dropout(drop_out_rate) if drop_out_rate > 0.0 else nn.Identity() + ) + self.dropout2 = ( + nn.Dropout(drop_out_rate) if drop_out_rate > 0.0 else nn.Identity() + ) + + self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True) + self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True) + + def forward(self, input): + inp = input[0] + cond = input[1] + + x = inp + + x = self.norm1(x) + x = self.conv1(x) + x = self.conv2(x) + x = self.sg(x) + x = x * self.sca(x, cond) + x = self.conv3(x) + + x = self.dropout1(x) + + y = inp + x * self.beta + + # Channel Mixing + x = self.conv4(self.norm2(y)) + x = self.sg(x) + # x = self.grn(x) + x = self.conv5(x) + + x = self.dropout2(x) + + return (y + x * self.gamma, cond) + + +class CondSEBlock(nn.Module): + def __init__(self, chan, reduction=16, cond_chan=1): + super().__init__() + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Sequential( + nn.Linear(chan + cond_chan, chan // reduction, bias=False), + nn.ReLU(inplace=True), + nn.Linear(chan // reduction, chan, bias=False), + nn.Sigmoid() + ) + + def forward(self, x, conditioning): + b, c, _, _ = x.size() + y = self.avg_pool(x).view(b, c) + y = torch.cat([y, conditioning], dim=1) + y = self.fc(y).view(b, c, 1, 1) + return x * y.expand_as(x) + +class CondFuser(nn.Module): + def __init__(self, chan, cond_chan=1): + super().__init__() + self.cca = ConditionedChannelAttention(chan * 2, cond_chan) + + def forward(self, x1, x2, cond): + x = torch.cat([x1, x2], dim=1) + x = self.cca(x, cond) * x + x1, x2 = x.chunk(2, dim=1) + return x1 + x2 + + +class Restorer(nn.Module): + def __init__( + self, + in_channels=3, + out_channels=3, + middle_blk_num=1, + enc_blk_nums=[], + dec_blk_nums=[], + chans = 
[], + cond_input=1, + cond_output=32, + expand_dims=2, + drop_out_rate=0.0, + drop_out_rate_increment=0.0, + rggb = False + ): + super().__init__() + width = chans[0] + + self.expand_dims = expand_dims + self.conditioning_gen = nn.Sequential( + nn.Linear(cond_input, 64), nn.ReLU(), nn.Dropout(drop_out_rate), nn.Linear(64, cond_output), + ) + self.rggb = rggb + if not rggb: + self.intro = nn.Conv2d( + in_channels=in_channels, + out_channels=width, + kernel_size=3, + padding=1, + stride=1, + groups=1, + bias=True, + ) + else: + self.intro = nn.Sequential( + + nn.Conv2d( + in_channels=in_channels, + out_channels=width * 2 ** 2, + kernel_size=3, + padding=1, + stride=1, + groups=1, + bias=True, + ), + nn.PixelShuffle(2) + ) + + nn.Conv2d( + in_channels=in_channels, + out_channels=width, + kernel_size=3, + padding=1, + stride=1, + groups=1, + bias=True, + ) + self.ending = nn.Conv2d( + in_channels=width, + out_channels=out_channels, + kernel_size=3, + padding=1, + stride=1, + groups=1, + bias=True, + ) + + self.encoders = nn.ModuleList() + self.decoders = nn.ModuleList() + self.middle_blks = nn.ModuleList() + self.ups = nn.ModuleList() + self.downs = nn.ModuleList() + # self.merges = nn.ModuleList() + + + # for num in enc_blk_nums: + for i in range(len(enc_blk_nums)): + current_chan = chans[i] + next_chan = chans[i + 1] + num = enc_blk_nums[i] + self.encoders.append( + nn.Sequential( + *[ + NAFBlock0(current_chan, cond_chans=cond_output, drop_out_rate=drop_out_rate) + for _ in range(num) + ] + ) + ) + drop_out_rate += drop_out_rate_increment + self.downs.append(nn.Conv2d(current_chan, next_chan, 2, 2)) + + self.middle_blks = nn.Sequential( + *[ + NAFBlock0(next_chan, cond_chans=cond_output, drop_out_rate=drop_out_rate) + for _ in range(middle_blk_num) + ] + ) + + for i in range(len(dec_blk_nums)): + current_chan = chans[-i-1] + next_chan = chans[-i-2] + num = dec_blk_nums[i] + self.ups.append( + nn.Sequential( + nn.Conv2d(current_chan, next_chan * 2 ** 2, 1, 
bias=False), nn.PixelShuffle(2) + ) + ) + drop_out_rate -= drop_out_rate_increment + # self.merges.append(CondFuser(next_chan, cond_chan=cond_output)) + self.decoders.append( + nn.Sequential( + *[ + NAFBlock0(next_chan, cond_chans=cond_output, drop_out_rate=drop_out_rate) + for _ in range(num) + ] + ) + ) + + + self.padder_size = 2 ** len(self.encoders) + + def forward(self, inp, cond_in): + # Conditioning: + cond = self.conditioning_gen(cond_in) + + B, C, H, W = inp.shape + if self.rggb: + H = 2 * H + W = 2 * W + inp = self.check_image_size(inp) + + x = self.intro(inp) + + encs = [] + for encoder, down in zip(self.encoders, self.downs): + x = encoder((x, cond))[0] + encs.append(x) + x = down(x) + + x = self.middle_blks((x, cond))[0] + + for decoder, up, enc_skip in zip(self.decoders, self.ups, encs[::-1]): + x = up(x) + x = x + enc_skip + # x = merge(x, enc_skip, cond) + x = decoder((x, cond))[0] + + x = self.ending(x) + return x[:, :, :H, :W] + + def check_image_size(self, x): + _, _, h, w = x.size() + mod_pad_h = (self.padder_size - h % self.padder_size) % self.padder_size + mod_pad_w = (self.padder_size - w % self.padder_size) % self.padder_size + x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h)) + return x + + +class ModelWrapperFullRGGB(nn.Module): + def __init__(self): + super().__init__() + self.model = Restorer( + chans = [64, 128, 256, 512, 1024], + enc_blk_nums = [2,2,4,8], + middle_blk_num = 12, + dec_blk_nums = [2, 2, 2, 2], + cond_input = 1, + cond_output=1, + in_channels = 4, + out_channels = 3, + rggb=True, + ) + + def forward(self, x, cond, residual): + output = self.model(x, cond) + return residual + output + + +def make_full_model_RGGB(model_name = None): + model = ModelWrapperFullRGGB() + if not model_name is None: + state_dict = torch.load(model_name, map_location="cpu") + model.load_state_dict(state_dict) + return model + +2+2+6+8+12+2+2+2+2 diff --git a/src/Restorer/Cond_NAF_ps.py b/src/Restorer/Cond_NAF_ps.py new file mode 100644 index 
0000000..6c79a18 --- /dev/null +++ b/src/Restorer/Cond_NAF_ps.py @@ -0,0 +1,310 @@ +import torch.nn.functional as F +import torch +import torch.nn as nn + +class LayerNorm2d(nn.Module): + def __init__(self, channels, eps=1e-6): + super(LayerNorm2d, self).__init__() + self.register_parameter("weight", nn.Parameter(torch.ones(channels))) + self.register_parameter("bias", nn.Parameter(torch.zeros(channels))) + self.eps = eps + + def forward(self, x): + mu = x.mean(1, keepdim=True) + var = (x - mu).pow(2).mean(1, keepdim=True) + + y = (x - mu) / torch.sqrt(var + self.eps) + + weight_view = self.weight.view(1, self.weight.size(0), 1, 1) + bias_view = self.bias.view(1, self.bias.size(0), 1, 1) + + y = weight_view * y + bias_view + return y + +class SimpleGate(nn.Module): + def forward(self, x): + x1, x2 = x.chunk(2, dim=1) + return x1 * x2 + + +class ConditionedChannelAttention(nn.Module): + def __init__(self, dims, cat_dims): + super().__init__() + in_dim = dims + cat_dims + # self.mlp = nn.Sequential( + # nn.Linear(in_dim, int(in_dim*1.5)), + # nn.GELU(), + # nn.Dropout(0.2), + # nn.Linear(int(in_dim*1.5), dims) + # ) + self.mlp = nn.Sequential(nn.Linear(in_dim, dims)) + self.pool = nn.AdaptiveAvgPool2d(1) + + def forward(self, x, conditioning): + pool = self.pool(x) + conditioning = conditioning.unsqueeze(-1).unsqueeze(-1) + cat_channels = torch.cat([pool, conditioning], dim=1) + cat_channels = cat_channels.permute(0, 2, 3, 1) + ca = self.mlp(cat_channels).permute(0, 3, 1, 2) + + return ca + + +class NAFBlock0(nn.Module): + def __init__(self, c, DW_Expand=2, FFN_Expand=2, drop_out_rate=0.0, cond_chans=0): + super().__init__() + dw_channel = c * DW_Expand + self.conv1 = nn.Conv2d( + in_channels=c, + out_channels=dw_channel, + kernel_size=1, + padding=0, + stride=1, + groups=1, + bias=True, + ) + self.conv2 = nn.Conv2d( + in_channels=dw_channel, + out_channels=dw_channel, + kernel_size=3, + padding=1, + stride=1, + groups=dw_channel, + bias=True, + ) + self.conv3 = 
nn.Conv2d( + in_channels=dw_channel // 2, + out_channels=c, + kernel_size=1, + padding=0, + stride=1, + groups=1, + bias=True, + ) + + # Simplified Channel Attention + self.sca = ConditionedChannelAttention(dw_channel // 2, cond_chans) + + # SimpleGate + self.sg = SimpleGate() + + ffn_channel = FFN_Expand * c + self.conv4 = nn.Conv2d( + in_channels=c, + out_channels=ffn_channel, + kernel_size=1, + padding=0, + stride=1, + groups=1, + bias=True, + ) + self.conv5 = nn.Conv2d( + in_channels=ffn_channel // 2, + out_channels=c, + kernel_size=1, + padding=0, + stride=1, + groups=1, + bias=True, + ) + + # self.grn = GRN(ffn_channel // 2) + + self.norm1 = LayerNorm2d(c) + self.norm2 = LayerNorm2d(c) + + self.dropout1 = ( + nn.Dropout(drop_out_rate) if drop_out_rate > 0.0 else nn.Identity() + ) + self.dropout2 = ( + nn.Dropout(drop_out_rate) if drop_out_rate > 0.0 else nn.Identity() + ) + + self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True) + self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True) + + def forward(self, input): + inp = input[0] + cond = input[1] + + x = inp + + x = self.norm1(x) + + x = self.conv1(x) + x = self.conv2(x) + x = self.sg(x) + x = x * self.sca(x, cond) + x = self.conv3(x) + + x = self.dropout1(x) + + y = inp + x * self.beta + + # Channel Mixing + x = self.conv4(self.norm2(y)) + x = self.sg(x) + # x = self.grn(x) + x = self.conv5(x) + + x = self.dropout2(x) + + return (y + x * self.gamma, cond) + +class Restorer(nn.Module): + def __init__( + self, + in_channels=3, + out_channels=3, + width=16, + middle_blk_num=1, + enc_blk_nums=[], + dec_blk_nums=[], + cond_input=1, + cond_output=32, + expand_dims=2, + drop_out_rate=0.0, + drop_out_rate_increment=0.0 + ): + super().__init__() + + self.expand_dims = expand_dims + self.conditioning_gen = nn.Sequential( + nn.Linear(cond_input, 64), nn.ReLU(), nn.Dropout(drop_out_rate), nn.Linear(64, cond_output), + ) + + self.intro = nn.Conv2d( + in_channels=in_channels, + 
out_channels=width, + kernel_size=3, + padding=1, + stride=1, + groups=1, + bias=True, + ) + self.ending = nn.Conv2d( + in_channels=width, + out_channels=out_channels, + kernel_size=3, + padding=1, + stride=1, + groups=1, + bias=True, + ) + + self.encoders = nn.ModuleList() + self.decoders = nn.ModuleList() + self.middle_blks = nn.ModuleList() + self.ups = nn.ModuleList() + self.downs = nn.ModuleList() + + chan = width + # for num in enc_blk_nums: + for i in range(len(enc_blk_nums)): + num = enc_blk_nums[i] + self.encoders.append( + nn.Sequential( + *[ + NAFBlock0(chan, cond_chans=cond_output, drop_out_rate=drop_out_rate) + for _ in range(num) + ] + ) + ) + drop_out_rate += drop_out_rate_increment + self.downs.append(nn.Conv2d(chan, 2 * chan, 2, 2)) + chan = chan * 2 + + self.middle_blks = nn.Sequential( + *[ + NAFBlock0(chan, cond_chans=cond_output, drop_out_rate=drop_out_rate) + for _ in range(middle_blk_num) + ] + ) + + for i in range(len(dec_blk_nums)): + num = dec_blk_nums[i] + self.ups.append( + nn.Sequential( + nn.Conv2d(chan, chan * 2, 1, bias=False), nn.PixelShuffle(2) + ) + ) + drop_out_rate -= drop_out_rate_increment + chan = chan // 2 + self.decoders.append( + nn.Sequential( + *[ + NAFBlock0(chan, cond_chans=cond_output, drop_out_rate=drop_out_rate) + for _ in range(num) + ] + ) + ) + + + self.padder_size = 2 ** len(self.encoders) + + def forward(self, inp, cond_in): + # Conditioning: + cond = self.conditioning_gen(cond_in) + + B, C, H, W = inp.shape + inp = self.check_image_size(inp) + + x = self.intro(inp) + + encs = [] + for encoder, down in zip(self.encoders, self.downs): + x = encoder((x, cond))[0] + encs.append(x) + x = down(x) + + x = self.middle_blks((x, cond))[0] + + for decoder, up, enc_skip in zip(self.decoders, self.ups, encs[::-1]): + x = up(x) + x = x + enc_skip + x = decoder((x, cond))[0] + + x = self.ending(x) + + return x[:, :, :H, :W] + + def check_image_size(self, x): + _, _, h, w = x.size() + mod_pad_h = (self.padder_size - h % 
self.padder_size) % self.padder_size + mod_pad_w = (self.padder_size - w % self.padder_size) % self.padder_size + x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h)) + return x + + + +class AddPixelShuffle(nn.Module): + def __init__(self, model, in_channels=4, out_channels=3): + super().__init__() + self.model = model + self.ps = nn.PixelShuffle(2) + + def forward(self, x, iso): + x = self.model(x, iso) + x = self.ps(x) + return x + + + + +# Get the model +def load_model(weight_file_path): + + model = Restorer( + width=64, + enc_blk_nums = [2, 2, 4, 8], + middle_blk_num = 12, + dec_blk_nums = [2, 2, 2, 2], + cond_input = 2, + in_channels = 4, + out_channels = 12, + ) + model = AddPixelShuffle(model) + + state_dict = torch.load(weight_file_path, map_location=torch.device('cpu')) + model.load_state_dict(state_dict) + return model diff --git a/src/Restorer/Restorer.py b/src/Restorer/Restorer.py index 7f39126..9ebb8e8 100644 --- a/src/Restorer/Restorer.py +++ b/src/Restorer/Restorer.py @@ -453,3 +453,105 @@ def forward(self, x, iso): x = self.model(x, iso) x = self.ps(x) return x + + + + +class NAFBlockCond(nn.Module): + def __init__(self, c, DW_Expand=2, FFN_Expand=2, drop_out_rate=0.0, cond_chans=0): + super().__init__() + dw_channel = c * DW_Expand + self.conv1 = nn.Conv2d( + in_channels=c, + out_channels=dw_channel, + kernel_size=1, + padding=0, + stride=1, + groups=1, + bias=True, + ) + self.conv2 = nn.Conv2d( + in_channels=dw_channel, + out_channels=dw_channel, + kernel_size=3, + padding=1, + stride=1, + groups=dw_channel, + bias=True, + ) + self.conv3 = nn.Conv2d( + in_channels=dw_channel // 2, + out_channels=c, + kernel_size=1, + padding=0, + stride=1, + groups=1, + bias=True, + ) + + # Simplified Channel Attention + self.sca = ConditionedChannelAttention(dw_channel // 2, cond_chans) + + # SimpleGate + self.sg = SimpleGate() + + ffn_channel = FFN_Expand * c + self.conv4 = nn.Conv2d( + in_channels=c, + out_channels=ffn_channel, + kernel_size=1, + padding=0, + stride=1, + 
groups=1, + bias=True, + ) + self.conv5 = nn.Conv2d( + in_channels=ffn_channel // 2, + out_channels=c, + kernel_size=1, + padding=0, + stride=1, + groups=1, + bias=True, + ) + + # self.grn = GRN(ffn_channel // 2) + + self.norm1 = LayerNorm2d(c) + self.norm2 = LayerNorm2d(c) + + self.dropout1 = ( + nn.Dropout(drop_out_rate) if drop_out_rate > 0.0 else nn.Identity() + ) + self.dropout2 = ( + nn.Dropout(drop_out_rate) if drop_out_rate > 0.0 else nn.Identity() + ) + + self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True) + self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True) + + def forward(self, inp, cond): + + x = inp + + x = self.norm1(x) + + x = self.conv1(x) + x = self.conv2(x) + x = self.sg(x) + x = x * self.sca(x, cond) + x = self.conv3(x) + + x = self.dropout1(x) + + y = inp + x * self.beta + + # Channel Mixing + x = self.conv4(self.norm2(y)) + x = self.sg(x) + # x = self.grn(x) + x = self.conv5(x) + + x = self.dropout2(x) + + return y + x * self.gamma, cond diff --git a/src/Restorer/debayering.py b/src/Restorer/debayering.py new file mode 100644 index 0000000..79b1bd0 --- /dev/null +++ b/src/Restorer/debayering.py @@ -0,0 +1,381 @@ +# Adapted from https://github.com/cheind/pytorch-debayer/blob/master/debayer/modules.py + +import torch +import torch.nn +import torch.nn.functional + +import enum + + +class Layout(enum.Enum): + """Possible Bayer color filter array layouts. + + The value of each entry is the color index (R=0,G=1,B=2) + within a 2x2 Bayer block. + """ + + RGGB = (0, 1, 1, 2) + GRBG = (1, 0, 2, 1) + GBRG = (1, 2, 0, 1) + BGGR = (2, 1, 1, 0) + +class Debayer3x3(torch.nn.Module): + """Demosaicing of Bayer images using 3x3 convolutions. + + Compared to Debayer2x2 this method does not use upsampling. + Instead, we identify five 3x3 interpolation kernels that + are sufficient to reconstruct every color channel at every + pixel location. 
+ + We convolve the image with these 5 kernels using stride=1 + and a one pixel reflection padding. Finally, we gather + the correct channel values for each pixel location. Todo so, + we recognize that the Bayer pattern repeats horizontally and + vertically every 2 pixels. Therefore, we define the correct + index lookups for a 2x2 grid cell and then repeat to image + dimensions. + """ + + def __init__(self, layout: Layout = Layout.RGGB): + super(Debayer3x3, self).__init__() + self.layout = layout + # fmt: off + self.kernels = torch.nn.Parameter( + torch.tensor( + [ + [0, 0.25, 0], + [0.25, 0, 0.25], + [0, 0.25, 0], + + [0.25, 0, 0.25], + [0, 0, 0], + [0.25, 0, 0.25], + + [0, 0, 0], + [0.5, 0, 0.5], + [0, 0, 0], + + [0, 0.5, 0], + [0, 0, 0], + [0, 0.5, 0], + ] + ).view(4, 1, 3, 3), + requires_grad=False, + ) + # fmt: on + + self.index = torch.nn.Parameter( + self._index_from_layout(layout), + requires_grad=False, + ) + + def forward(self, x): + """Debayer image. + + Parameters + ---------- + x : Bx1xHxW tensor + Images to debayer + + Returns + ------- + rgb : Bx3xHxW tensor + Color images in RGB channel order. + """ + B, C, H, W = x.shape + + xpad = torch.nn.functional.pad(x, (1, 1, 1, 1), mode="reflect") + c = torch.nn.functional.conv2d(xpad, self.kernels, stride=1) + c = torch.cat((c, x), 1) # Concat with input to give identity kernel Bx5xHxW + + rgb = torch.gather( + c, + 1, + self.index.repeat( + 1, + 1, + torch.div(H, 2, rounding_mode="floor"), + torch.div(W, 2, rounding_mode="floor"), + ).expand( + B, -1, -1, -1 + ), # expand in batch is faster than repeat + ) + return rgb + + def _index_from_layout(self, layout: Layout) -> torch.Tensor: + """Returns a 1x3x2x2 index tensor for each color RGB in a 2x2 bayer tile. + + Note, the index corresponding to the identity kernel is 4, which will be + correct after concatenating the convolved output with the input image. + """ + # ... + # ... b g b g ... + # ... g R G r ... + # ... b G B g ... + # ... g r g r ... + # ... 
+ # fmt: off + rggb = torch.tensor( + [ + # dest channel r + [4, 2], # pixel is R,G1 + [3, 1], # pixel is G2,B + # dest channel g + [0, 4], # pixel is R,G1 + [4, 0], # pixel is G2,B + # dest channel b + [1, 3], # pixel is R,G1 + [2, 4], # pixel is G2,B + ] + ).view(1, 3, 2, 2) + # fmt: on + return { + Layout.RGGB: rggb, + Layout.GRBG: torch.roll(rggb, 1, -1), + Layout.GBRG: torch.roll(rggb, 1, -2), + Layout.BGGR: torch.roll(rggb, (1, 1), (-1, -2)), + }.get(layout) + + +class Debayer2x2(torch.nn.Module): + """Fast demosaicing of Bayer images using 2x2 convolutions. + + This method uses 3 kernels of size 2x2 and stride 2. Each kernel + corresponds to a single color RGB. For R and B the corresponding + value from each 2x2 Bayer block is taken according to the layout. + For G, G1 and G2 are averaged. The resulting image has half width/ + height and is upsampled by a factor of 2. + """ + + def __init__(self, layout: Layout = Layout.RGGB): + super(Debayer2x2, self).__init__() + self.layout = layout + + self.kernels = torch.nn.Parameter( + self._kernels_from_layout(layout), + requires_grad=False, + ) + + def forward(self, x): + """Debayer image. + + Parameters + ---------- + x : Bx1xHxW tensor + Images to debayer + + Returns + ------- + rgb : Bx3xHxW tensor + Color images in RGB channel order. + """ + x = torch.nn.functional.conv2d(x, self.kernels, stride=2) + + x = torch.nn.functional.interpolate( + x, scale_factor=2, mode="bilinear", align_corners=False + ) + return x + + def _kernels_from_layout(self, layout: Layout) -> torch.Tensor: + v = torch.tensor(layout.value).reshape(2, 2) + r = torch.zeros(2, 2) + r[v == 0] = 1.0 + + g = torch.zeros(2, 2) + g[v == 1] = 0.5 + + b = torch.zeros(2, 2) + b[v == 2] = 1.0 + + k = torch.stack((r, g, b), 0).unsqueeze(1) # 3x1x2x2 + return k + + +class DebayerSplit(torch.nn.Module): + """Demosaicing of Bayer images using 3x3 green convolution and red,blue upsampling. + Requires Bayer layout `Layout.RGGB`. 
+ """ + + def __init__(self, layout: Layout = Layout.RGGB): + super().__init__() + if layout != Layout.RGGB: + raise NotImplementedError("DebayerSplit only implemented for RGGB layout.") + self.layout = layout + + self.pad = torch.nn.ReflectionPad2d(1) + self.kernel = torch.nn.Parameter( + torch.tensor([[0, 1, 0], [1, 0, 1], [0, 1, 0]])[None, None] * 0.25 + ) + + def forward(self, x): + """Debayer image. + + Parameters + ---------- + x : Bx1xHxW tensor + Images to debayer + + Returns + ------- + rgb : Bx3xHxW tensor + Color images in RGB channel order. + """ + B, _, H, W = x.shape + red = x[:, :, ::2, ::2] + blue = x[:, :, 1::2, 1::2] + + green = torch.nn.functional.conv2d(self.pad(x), self.kernel) + green[:, :, ::2, 1::2] = x[:, :, ::2, 1::2] + green[:, :, 1::2, ::2] = x[:, :, 1::2, ::2] + + return torch.cat( + ( + torch.nn.functional.interpolate( + red, size=(H, W), mode="bilinear", align_corners=False + ), + green, + torch.nn.functional.interpolate( + blue, size=(H, W), mode="bilinear", align_corners=False + ), + ), + dim=1, + ) + + +class Debayer5x5(torch.nn.Module): + """Demosaicing of Bayer images using Malver-He-Cutler algorithm. + + Requires BG-Bayer color filter array layout. That is, + the image[1,1]='B', image[1,2]='G'. This corresponds + to OpenCV naming conventions. + + Compared to Debayer2x2 this method does not use upsampling. + Compared to Debayer3x3 the algorithm gives sharper edges and + less chromatic effects. + + ## References + Malvar, Henrique S., Li-wei He, and Ross Cutler. + "High-quality linear interpolation for demosaicing of Bayer-patterned + color images." 
2004 + """ + + def __init__(self, layout: Layout = Layout.RGGB): + super(Debayer5x5, self).__init__() + self.layout = layout + # fmt: off + self.kernels = torch.nn.Parameter( + torch.tensor( + [ + # G at R,B locations + # scaled by 16 + [ 0, 0, -2, 0, 0], # noqa + [ 0, 0, 4, 0, 0], # noqa + [-2, 4, 8, 4, -2], # noqa + [ 0, 0, 4, 0, 0], # noqa + [ 0, 0, -2, 0, 0], # noqa + + # R,B at G in R rows + # scaled by 16 + [ 0, 0, 1, 0, 0], # noqa + [ 0, -2, 0, -2, 0], # noqa + [-2, 8, 10, 8, -2], # noqa + [ 0, -2, 0, -2, 0], # noqa + [ 0, 0, 1, 0, 0], # noqa + + # R,B at G in B rows + # scaled by 16 + [ 0, 0, -2, 0, 0], # noqa + [ 0, -2, 8, -2, 0], # noqa + [ 1, 0, 10, 0, 1], # noqa + [ 0, -2, 8, -2, 0], # noqa + [ 0, 0, -2, 0, 0], # noqa + + # R at B and B at R + # scaled by 16 + [ 0, 0, -3, 0, 0], # noqa + [ 0, 4, 0, 4, 0], # noqa + [-3, 0, 12, 0, -3], # noqa + [ 0, 4, 0, 4, 0], # noqa + [ 0, 0, -3, 0, 0], # noqa + + # R at R, B at B, G at G + # identity kernel not shown + ] + ).view(4, 1, 5, 5).float() / 16.0, + requires_grad=False, + ) + # fmt: on + + self.index = torch.nn.Parameter( + # Below, note that index 4 corresponds to identity kernel + self._index_from_layout(layout), + requires_grad=False, + ) + + def forward(self, x): + """Debayer image. + + Parameters + ---------- + x : Bx1xHxW tensor + Images to debayer + + Returns + ------- + rgb : Bx3xHxW tensor + Color images in RGB channel order. 
+ """ + B, C, H, W = x.shape + + xpad = torch.nn.functional.pad(x, (2, 2, 2, 2), mode="reflect") + planes = torch.nn.functional.conv2d(xpad, self.kernels, stride=1) + planes = torch.cat( + (planes, x), 1 + ) # Concat with input to give identity kernel Bx5xHxW + rgb = torch.gather( + planes, + 1, + self.index.repeat( + 1, + 1, + torch.div(H, 2, rounding_mode="floor"), + torch.div(W, 2, rounding_mode="floor"), + ).expand( + B, -1, -1, -1 + ), # expand for singleton batch dimension is faster + ) + return torch.clamp(rgb, 0, 1) + + def _index_from_layout(self, layout: Layout) -> torch.Tensor: + """Returns a 1x3x2x2 index tensor for each color RGB in a 2x2 bayer tile. + + Note, the index corresponding to the identity kernel is 4, which will be + correct after concatenating the convolved output with the input image. + """ + # ... + # ... b g b g ... + # ... g R G r ... + # ... b G B g ... + # ... g r g r ... + # ... + # fmt: off + rggb = torch.tensor( + [ + # dest channel r + [4, 1], # pixel is R,G1 + [2, 3], # pixel is G2,B + # dest channel g + [0, 4], # pixel is R,G1 + [4, 0], # pixel is G2,B + # dest channel b + [3, 2], # pixel is R,G1 + [1, 4], # pixel is G2,B + ] + ).view(1, 3, 2, 2) + # fmt: on + return { + Layout.RGGB: rggb, + Layout.GRBG: torch.roll(rggb, 1, -1), + Layout.GBRG: torch.roll(rggb, 1, -2), + Layout.BGGR: torch.roll(rggb, (1, 1), (-1, -2)), + }.get(layout) \ No newline at end of file diff --git a/src/training/DemosaicingDataset.py b/src/training/DemosaicingDataset.py new file mode 100644 index 0000000..dfed023 --- /dev/null +++ b/src/training/DemosaicingDataset.py @@ -0,0 +1,247 @@ +import pandas as pd +import os +from torch.utils.data import Dataset +import imageio +from colour_demosaicing import ( + mosaicing_CFA_Bayer +) +import numpy as np +import torch +import torch.nn.functional as F +from pathlib import Path +from RawHandler.RawHandler import RawHandler +from src.training.utils import cfa_to_sparse + +def random_crop_dim(shape, crop_size, 
buffer, validation=False): + """ + Calculates random (or centered) crop dimensions, ensuring even coordinates. + """ + h, w = shape + + # Ensure crop size is even + crop_size = (crop_size // 2) * 2 + + if not validation: + top = np.random.randint(0 + buffer, h - crop_size - buffer) + left = np.random.randint(0 + buffer, w - crop_size - buffer) + else: + top = (h - crop_size) // 2 + left = (w - crop_size) // 2 + + # Ensure top-left corner is even for correct Bayer pattern alignment + if top % 2 != 0: + top -= 1 + if left % 2 != 0: + left -= 1 + + # Handle potential boundary issues after adjustment + top = max(0, top) + left = max(0, left) + if top + crop_size > h: + top = h - crop_size + if left + crop_size > w: + left = w - crop_size + + bottom = top + crop_size + right = left + crop_size + return (left, right, top, bottom) + +# def cfa_to_sparse(cfa_image, pattern='RGGB'): +# """ +# Converts a 2D CFA image (H, W) to a 3D sparse RGB image (H, W, 3). +# """ +# H, W = cfa_image.shape +# sparse_rgb = np.zeros((H, W, 3), dtype=cfa_image.dtype) + +# # +# if pattern == 'RGGB': +# # R +# sparse_rgb[0::2, 0::2, 0] = cfa_image[0::2, 0::2] +# # G (top-right) +# sparse_rgb[0::2, 1::2, 1] = cfa_image[0::2, 1::2] +# # G (bottom-left) +# sparse_rgb[1::2, 0::2, 1] = cfa_image[1::2, 0::2] +# # B +# sparse_rgb[1::2, 1::2, 2] = cfa_image[1::2, 1::2] +# elif pattern == 'GRBG': +# sparse_rgb[0::2, 0::2, 1] = cfa_image[0::2, 0::2] +# sparse_rgb[0::2, 1::2, 0] = cfa_image[0::2, 1::2] +# sparse_rgb[1::2, 0::2, 2] = cfa_image[1::2, 0::2] +# sparse_rgb[1::2, 1::2, 1] = cfa_image[1::2, 1::2] +# elif pattern == 'GBRG': +# sparse_rgb[0::2, 0::2, 1] = cfa_image[0::2, 0::2] +# sparse_rgb[0::2, 1::2, 2] = cfa_image[0::2, 1::2] +# sparse_rgb[1::2, 0::2, 0] = cfa_image[1::2, 0::2] +# sparse_rgb[1::2, 1::2, 1] = cfa_image[1::2, 1::2] +# elif pattern == 'BGGR': +# sparse_rgb[0::2, 0::2, 2] = cfa_image[0::2, 0::2] +# sparse_rgb[0::2, 1::2, 1] = cfa_image[0::2, 1::2] +# sparse_rgb[1::2, 0::2, 1] = 
cfa_image[1::2, 0::2] +# sparse_rgb[1::2, 1::2, 0] = cfa_image[1::2, 1::2] +# else: +# raise NotImplementedError(f"Pattern {pattern} not implemented") + +# return sparse_rgb + +def cfa_to_rggb_stack(cfa_image, pattern='RGGB'): + """ + Converts a (H, W) CFA image to an (4, H/2, W/2) RGGB stack. + """ + assert cfa_image.ndim == 2, "Input must be (H, W)" + H, W = cfa_image.shape + assert H % 2 == 0 and W % 2 == 0, "Height and width must be even" + + if pattern == 'RGGB': + R = cfa_image[0::2, 0::2] + G1 = cfa_image[0::2, 1::2] # G at top-right + G2 = cfa_image[1::2, 0::2] # G at bottom-left + B = cfa_image[1::2, 1::2] + elif pattern == 'GRBG': + G1 = cfa_image[0::2, 0::2] + R = cfa_image[0::2, 1::2] + B = cfa_image[1::2, 0::2] + G2 = cfa_image[1::2, 1::2] + elif pattern == 'GBRG': + G1 = cfa_image[0::2, 0::2] + B = cfa_image[0::2, 1::2] + R = cfa_image[1::2, 0::2] + G2 = cfa_image[1::2, 1::2] + elif pattern == 'BGGR': + B = cfa_image[0::2, 0::2] + G1 = cfa_image[0::2, 1::2] + G2 = cfa_image[1::2, 0::2] + R = cfa_image[1::2, 1::2] + else: + raise NotImplementedError(f"Pattern {pattern} not implemented") + + # Stack R, G1, G2, B + rggb_stack = np.stack([R, G1, G2, B], axis=0) + return rggb_stack + +def pixel_unshuffle(x, r): + C, H, W = x.shape + x = ( + x.reshape(C, H // r, r, W // r, r) + .transpose(0, 2, 4, 1, 3) + .reshape(C * r**2, H // r, W // r) + ) + return x + +class DemosaicingDataset(Dataset): + """ + Dataset for learned demosaicing. + + Workflow: + 1. Load High-Res Noisy DNG. + 2. Crop a large patch, get its RGB representation (e.g., from RawHandler). + 3. Downsample this RGB patch (area) to create the Ground Truth image. + 4. Apply Bayer mosaicing to the Ground Truth to create the network input. + 5. Provide input in sparse (3-ch) and RGGB (4-ch) formats. 
+ """ + def __init__(self, path, csv, colorspace, + output_crop_size=256, + downsample_factor=2, + buffer=10, + validation=False, + bayer_pattern='RGGB'): + super().__init__() + self.df = pd.read_csv(csv) + self.path = Path(path) + self.output_crop_size = (output_crop_size // 2) * 2 # Ensure even + self.downsample_factor = downsample_factor + self.input_crop_size = self.output_crop_size * self.downsample_factor + self.buffer = buffer + self.coordinate_iso = 6400.0 # Normalization constant for ISO + self.validation = validation + self.dtype = np.float32 # Use float32 for tensors + self.colorspace = colorspace + self.bayer_pattern = bayer_pattern + + def __len__(self): + return len(self.df) + + def __getitem__(self, idx): + row = self.df.iloc[idx] + + try: + name = Path(f"{row.bayer_path}").name + name = str(self.path / name.replace('_bayer.jpg', '.dng')) + noisy_rh = RawHandler(name) + except Exception as e: + print(f"Error loading {name}: {e}") + return self.__getitem__((idx + 1) % len(self)) # Skip bad file + + try: + dims = random_crop_dim( + noisy_rh.raw.shape, + self.input_crop_size, + self.buffer, + validation=self.validation + ) + + # Get (3, H_in, W_in) RGB patch from RawHandler + noisy_rgb_full = noisy_rh.as_rgb( + dims=dims, + colorspace=self.colorspace + ).astype(self.dtype) + + # 3. Downsample to create Ground Truth + # Convert to (1, 3, H_in, W_in) tensor for interpolation + noisy_tensor = torch.from_numpy(noisy_rgb_full).unsqueeze(0) + + gt_tensor = F.interpolate( + noisy_tensor, + scale_factor=1.0/self.downsample_factor, + mode='area', + recompute_scale_factor=False # Use scale_factor directly + ) + + # Back to (3, H_out, W_out) and clip + gt_tensor = gt_tensor.squeeze(0).clip(0, 1) + + # 4. 
Apply Bayer mosaicing to create network input + # Convert GT to (H_out, W_out, 3) numpy array + gt_numpy = gt_tensor.permute(1, 2, 0).numpy() + + # Create (H_out, W_out) 2D CFA + input_cfa_hw = mosaicing_CFA_Bayer( + gt_numpy, + pattern=self.bayer_pattern + ) + + # 5. Provide input in sparse and RGGB formats + + # Create (H_out, W_out, 3) sparse array + input_sparse_chw = cfa_to_sparse( + input_cfa_hw, + pattern=self.bayer_pattern + )[0] + # Convert to (3, H_out, W_out) tensor + input_sparse_chw = torch.from_numpy( + input_sparse_chw + ).to(torch.float32) + + + # Create (4, H_out/2, W_out/2) RGGB stack + input_cfa_hw_expanded = np.expand_dims(input_cfa_hw, 0) + input_rggb_stack = pixel_unshuffle(input_cfa_hw_expanded, 2) + + # Convert to (4, H_out/2, W_out/2) tensor + input_rggb_tensor = torch.from_numpy(input_rggb_stack).to(torch.float32) + + # 6. Conditioning tensor + conditioning = torch.tensor( + [row.iso / self.coordinate_iso] + ).to(torch.float32) + + output = { + "ground_truth": gt_tensor, # (3, H_out, W_out) + "cfa_sparse": input_sparse_chw, # (3, H_out, W_out) + "cfa_rggb": input_rggb_tensor, # (4, H_out/2, W_out/2) + "conditioning": conditioning # (1,) + } + return output + + except Exception as e: + print(f"Error processing {name} (idx {idx}): {e}") + return self.__getitem__((idx + 1) % len(self)) # Skip bad file diff --git a/src/training/RawDataset.py b/src/training/RawDataset.py new file mode 100644 index 0000000..49d7a1b --- /dev/null +++ b/src/training/RawDataset.py @@ -0,0 +1,155 @@ +import torch +from torch.utils.data import Dataset +import pandas as pd +import numpy as np +from matplotlib.colors import rgb_to_hsv, hsv_to_rgb +import random +import cv2 +from pathlib import Path +from RawHandler.RawHandler import RawHandler +from RawHandler.utils import linear_to_srgb, pixel_unshuffle, pixel_shuffle +from colour_demosaicing import demosaicing_CFA_Bayer_Malvar2004 +from src.training.image_utils import simulate_sparse, bilinear_demosaic + +def 
normalized_cross_correlation(im1, im2): + im1 = im1 - np.mean(im1) + im2 = im2 - np.mean(im2) + numerator = np.sum(im1 * im2) + denominator = np.sqrt(np.sum(im1**2) * np.sum(im2**2)) + return numerator / denominator if denominator != 0 else 0.0 + +class RawDataset(Dataset): + def __init__(self, csv_path, patch_size=256, cc_threshold=0.92, colorspace='lin_rec2020'): + self.df = pd.read_csv(csv_path) + + # Drop rows with missing 'cc' or invalid types + self.df = self.df[pd.to_numeric(self.df['cc'], errors='coerce').notnull()] + self.df['cc'] = self.df['cc'].astype(float) + + # Filter based on cc threshold + original_len = len(self.df) + self.df = self.df[self.df['cc'] >= cc_threshold] + filtered_len = len(self.df) + + print(f"[Dataset] Filtered out {original_len - filtered_len} rows below cc threshold of {cc_threshold}.") + + self.patch_size = patch_size + self.colorspace = colorspace + + def __len__(self): + return len(self.df) + + + def __getitem__(self, idx): + row = self.df.iloc[idx] + + # Load images using RawHandler + gt = RawHandler(row["gt_image"], colorspace=self.colorspace) + + # Get dimensions + H, W = gt.raw.shape # Assume same for noisy and gt + + # Random crop coordinates + if H < self.patch_size or W < self.patch_size: + raise ValueError(f"Image is smaller than patch size: {H}x{W} < {self.patch_size}") + + align_offset = 20 + x1 = random.randint(0 + align_offset, W - (self.patch_size + align_offset)) + y1 = random.randint(0 + align_offset, H - (self.patch_size + align_offset)) + return self.get_patches(idx, x1, y1) + + + def get_patches(self, idx, x1, y1): + row = self.df.iloc[idx] + + # Load images using RawHandler + noisy = RawHandler(row["noisy_image"], colorspace=self.colorspace) + gt = RawHandler(row["gt_image"], colorspace=self.colorspace) + + # Get dimensions + H, W = gt.raw.shape # Assume same for noisy and gt + + # Random crop coordinates + if H < self.patch_size or W < self.patch_size: + raise ValueError(f"Image is smaller than patch size: 
{H}x{W} < {self.patch_size}") + + align_offset = 20 + x2 = x1 + self.patch_size + y2 = y1 + self.patch_size + crop_dim = (y1, y2, x1, x2) + + # Convert to float32 numpy arrays + expand_crop_dim = (y1-align_offset, y2+align_offset, x1-align_offset, x2+align_offset) + noisy_patch = noisy.as_rgb(dims=expand_crop_dim, demosaicing_func=demosaicing_CFA_Bayer_Malvar2004).astype(np.float32) + noisy_patch = noisy_patch[:, align_offset:-align_offset, align_offset:-align_offset] + + align_crop_dim = (y1-align_offset, y2+align_offset, x1-align_offset, x2+align_offset) + gt_patch = gt.as_rgb(dims=align_crop_dim, demosaicing_func=demosaicing_CFA_Bayer_Malvar2004).astype(np.float32) + + # Align gt image + gt_image = gt_patch.transpose(1, 2, 0) + noisy_image = noisy_patch.transpose(1, 2, 0) + + warp_matrix = np.array([[1, 0, -row.x_warp + align_offset], + [0, 1, -row.y_warp + align_offset]]) + gt_image = cv2.warpAffine(gt_image, warp_matrix, + (noisy_image.shape[1], noisy_image.shape[0]), flags=cv2.INTER_LINEAR + cv2.WARP_INVERSE_MAP) + + ncc = normalized_cross_correlation(gt_image, noisy_image) + noise_level = (gt_image-noisy_image).std(axis=(0, 1)) + + gt_patch = gt_image.transpose(2, 0, 1) + + # Adjust brightness + seperate_channels = False + if seperate_channels: + gains = gt_patch.mean(axis=(1, 2))/noisy_patch.mean(axis=(1, 2)) + + # Known issue, since the gt will clip at 1 and adjust by 2, we might have a max val of .5 + gt_patch *= 1/gains.reshape(3, 1, 1) + gt_patch = np.clip(gt_patch, 0, 1) + else: + gains = gt_patch.mean()/noisy_patch.mean() + gt_patch *= 1/gains + gt_patch = np.clip(gt_patch, 0, 1) + + # Make sparse and bilinear output + noisy_sparse = noisy.as_sparse(dims=expand_crop_dim).astype(np.float32) + noisy_sparse = noisy_sparse[:, align_offset:-align_offset, align_offset:-align_offset] + gt_sparse, g_sparse_mask = simulate_sparse(gt_patch) + bilinear = bilinear_demosaic(noisy_sparse) + noisy_sparse = torch.from_numpy(noisy_sparse).float() + noisy_tensor = 
torch.from_numpy(noisy_patch).float() + bilinear_tensor = torch.from_numpy(bilinear).float() + gt_tensor = torch.from_numpy(gt_patch).float() + gt_sparse_tensor = torch.from_numpy(gt_sparse).float() + + return_dict = { + 'noisy_sparse': noisy_sparse, + 'noisy_tensor': noisy_tensor, + 'bilinear_tensor': bilinear_tensor, + 'gt': gt_tensor, + 'gt_sparse': gt_sparse_tensor, + # "noisy_rggb_tensor": noisy_rggb_tensor, + # "conditioning": conditioning, + 'idx': idx, + 'cc': row['cc'], + 'ncc': ncc, + 'noise_level': noise_level, + 'iso': row['iso'], + # 'x_warp': row['x_warp'], + # 'y_warp': row['y_warp'], + # 'gt_mean': row['gt_mean'], + # 'noisy_mean': row['noisy_mean'], + 'gains': gains, + # 'noisy_rggb_tensor': noisy_rggb_tensor + } + + using_NAF = True + if using_NAF: + noisy_rggb = noisy.as_rggb( dims=crop_dim) + noisy_rggb_tensor = torch.from_numpy(noisy_rggb).float() + conditioning = torch.tensor([float(row['iso'])/6400, 0, 0, 0]).float() + return_dict['conditioning'] = conditioning + return_dict['noisy_rggb_tensor'] = noisy_rggb_tensor + return return_dict diff --git a/src/training/RawDatasetDNG.py b/src/training/RawDatasetDNG.py new file mode 100644 index 0000000..6c9a162 --- /dev/null +++ b/src/training/RawDatasetDNG.py @@ -0,0 +1,147 @@ +import pandas as pd +import os +from torch.utils.data import Dataset +import imageio +from colour_demosaicing import ( + ROOT_RESOURCES_EXAMPLES, + demosaicing_CFA_Bayer_bilinear, + demosaicing_CFA_Bayer_Malvar2004, + demosaicing_CFA_Bayer_Menon2007, + mosaicing_CFA_Bayer) + +from src.training.utils import inverse_gamma_tone_curve, cfa_to_sparse +import numpy as np +import torch +from src.training.align_images import apply_alignment, align_clean_to_noisy +from pathlib import Path +from RawHandler.RawHandler import RawHandler + + + +def global_affine_match(A, D, mask=None): + """ + Fit D ≈ a + b*A with least squares. 
+ A, D : 2D arrays, same shape (linear values) + mask : optional boolean array, True=use pixel + returns: a, b, D_pred, D_resid (D - (a + b*A)) + """ + A = A.ravel().astype(np.float64) + D = D.ravel().astype(np.float64) + if mask is None: + mask = np.isfinite(A) & np.isfinite(D) + else: + mask = mask.ravel() & np.isfinite(A) & np.isfinite(D) + + A0 = A[mask] + D0 = D[mask] + # design matrix [1, A] + X = np.vstack([np.ones_like(A0), A0]).T + coef, *_ = np.linalg.lstsq(X, D0, rcond=None) + a, b = coef[0], coef[1] + D_pred = (a + b * A).reshape(-1) + D_pred = D_pred.reshape(A.shape) if False else (a + b * A).reshape((-1,)) # keep flatten + + return a, b, (a + b * A) + + +def random_crop_dim(shape, crop_size, buffer, validation=False): + h, w = shape + if not validation: + top = np.random.randint(0 + buffer, h - crop_size - buffer) + left = np.random.randint(0 + buffer, w - crop_size - buffer) + else: + top = (h - crop_size) // 2 + left = (w - crop_size) // 2 + + if top % 2 != 0: top = top - 1 + if left % 2 != 0: left = left - 1 + bottom = top + crop_size + right = left + crop_size + return (left, right, top, bottom) + +def check_align_matrix(row, tolerance=1e-7): + is_identity = np.isclose(row['M00'], 1.0, atol=tolerance) and \ + np.isclose(row['M01'], 0.0, atol=tolerance) and \ + np.isclose(row['M10'], 0.0, atol=tolerance) and \ + np.isclose(row['M11'], 1.0, atol=tolerance) + + assert is_identity, "Rotations, scalings, or shearing are not tested." 
+ + +def round_to_nearest_2(number): + return round(number / 2) * 2 + + +class RawDatasetDNG(Dataset): + def __init__(self, path, csv, colorspace, crop_size=180, buffer=10, validation=False, run_align=False, dimensions=2000): + super().__init__() + self.df = pd.read_csv(csv) + self.path = path + self.crop_size = crop_size + self.buffer = buffer + self.coordinate_iso = 6400 + self.validation=validation + self.run_align = run_align + self.dtype = np.float16 + self.dimensions = dimensions + self.colorspace = colorspace + + def __len__(self): + return len(self.df) + + def __getitem__(self, idx): + row = self.df.iloc[idx] + # Load images + try: + name = Path(f"{row.bayer_path}").name + name = str(self.path / name.replace('_bayer.jpg', '.dng')) + noisy_rh = RawHandler(name) + except: + print(name) + + try: + gt_name = Path(f"{row.gt_path}").name + gt_name = str(self.path / gt_name.replace('.jpg', '.dng')) + gt_rh = RawHandler(gt_name) + except: + print(gt_name) + + dims = random_crop_dim(noisy_rh.raw.shape, self.crop_size, self.buffer, validation=self.validation) + try: + bayer_data = noisy_rh.apply_colorspace_transform(dims=dims, colorspace=self.colorspace) + noisy = noisy_rh.as_rgb(dims=dims, colorspace=self.colorspace) + rggb = noisy_rh.as_rggb(dims=dims, colorspace=self.colorspace) + sparse = noisy_rh.as_sparse(dims=dims, colorspace=self.colorspace) + + check_align_matrix(row) + expanded_dims = [dims[0]-self.buffer, dims[1]+self.buffer, dims[2]-self.buffer, dims[3]+self.buffer] + gt_expanded = gt_rh.as_rgb(dims=expanded_dims, colorspace=self.colorspace) + aligned = apply_alignment(gt_expanded.transpose(1, 2, 0), row.to_dict())[self.buffer:-self.buffer, self.buffer:-self.buffer] + gt_non_aligned = gt_expanded.transpose(1, 2, 0)[self.buffer:-self.buffer, self.buffer:-self.buffer] + # # gt_non_aligned = gt_non_aligned * noisy.mean() / aligned.mean() + # # aligned = aligned * noisy.mean() / aligned.mean() + # # Get Raw data for testing + # noisy_raw = 
noisy_rh.raw[dims[0]:dims[1], dims[2]: dims[3]] + # row_dict = row.to_dict() + # shift_y, shift_x = round_to_nearest_2(row_dict['M12']), round_to_nearest_2(row_dict['M11']) + # gt_raw = gt_rh.raw[dims[0]+shift_y:dims[1]+shift_y, dims[2]+shift_x:dims[3]+shift_x] + except: + print(name, gt_name) + + # aligned = gt_rh.as_rgb(dims=dims, colorspace=self.colorspace).transpose(1, 2, 0) + + # Convert to tensors + output = { + "bayer": torch.tensor(bayer_data).to(float).clip(0,1), + "gt_non_aligned": torch.tensor(gt_non_aligned).to(float).permute(2, 0, 1).clip(0,1), + "aligned": torch.tensor(aligned).to(float).permute(2, 0, 1).clip(0,1), + "sparse": torch.tensor(sparse).to(float).clip(0,1), + "noisy": torch.tensor(noisy).to(float).clip(0,1), + "rggb": torch.tensor(rggb).to(float).clip(0,1), + "conditioning": torch.tensor([row.iso/self.coordinate_iso]).to(float), + # "noisy_raw": noisy_raw, + # "gt_raw": gt_raw, + # "noise_est": noise_est, + # "rggb_gt": rggb_gt, + } + return output \ No newline at end of file diff --git a/src/training/RawDatasetDNG_load_into_memory.py b/src/training/RawDatasetDNG_load_into_memory.py new file mode 100644 index 0000000..bafd62f --- /dev/null +++ b/src/training/RawDatasetDNG_load_into_memory.py @@ -0,0 +1,156 @@ +import pandas as pd +import os +from torch.utils.data import Dataset +import imageio +from colour_demosaicing import ( + ROOT_RESOURCES_EXAMPLES, + demosaicing_CFA_Bayer_bilinear, + demosaicing_CFA_Bayer_Malvar2004, + demosaicing_CFA_Bayer_Menon2007, + mosaicing_CFA_Bayer) + +# from src.training.utils import inverse_gamma_tone_curve, cfa_to_sparse +import numpy as np +import torch +# from src.training.align_images import apply_alignment, align_clean_to_noisy +from pathlib import Path +from RawHandler.RawHandler import RawHandler + +class RawDatasetDNG(Dataset): + def __init__(self, path, csv, colorspace, crop_size=180, buffer=10, + validation=False, run_align=False, + dimensions=2000, + apply_exposure_corr=True, + demosaicing_func = 
demosaicing_CFA_Bayer_Malvar2004): + super().__init__() + self.df = pd.read_csv(csv) + self.path = path + self.crop_size = crop_size + self.buffer = buffer + self.coordinate_iso = 6400 + self.validation=validation + self.run_align = run_align + self.dtype = np.float16 + self.dimensions = dimensions + self.colorspace = colorspace + self.apply_exposure_corr = apply_exposure_corr + self.demosaicing_func = demosaicing_func + + files = os.listdir(path) + files = [f for f in files if 'dng' in f] + files = [f for f in files if not 'xmp' in f] + self.rhs = {} + for file in files: + self.rhs[file] = RawHandler(f'Cropped_Raw/{file}') + + + def __len__(self): + return len(self.df) + + def __getitem__(self, idx): + row = self.df.iloc[idx] + # Load images + name = Path(f"{row.bayer_path}").name + name = name.replace('_bayer.jpg', '.dng') + noisy_rh = self.rhs[name] + + gt_name = Path(f"{row.gt_path}").name + gt_name = gt_name.replace('.jpg', '.dng') + gt_rh = self.rhs[gt_name] + + dims = random_crop_dim(noisy_rh.raw.shape, self.crop_size, self.buffer, validation=self.validation) + + bayer_data = noisy_rh.apply_colorspace_transform(dims=dims, colorspace=self.colorspace) + noisy = noisy_rh.as_rgb(dims=dims, colorspace=self.colorspace, demosaicing_func=self.demosaicing_func, clip=False) + rggb = noisy_rh.as_rggb(dims=dims, colorspace=self.colorspace, clip=False) + sparse = noisy_rh.as_sparse(dims=dims, colorspace=self.colorspace, clip=False) + + check_align_matrix(row) + expanded_dims = [dims[0]-self.buffer, dims[1]+self.buffer, dims[2]-self.buffer, dims[3]+self.buffer] + gt_expanded = gt_rh.as_rgb(dims=expanded_dims, colorspace=self.colorspace, demosaicing_func=self.demosaicing_func, clip=False) + if self.apply_exposure_corr: + gt_expanded[0] *= row['r_scale_factor'] + gt_expanded[1] *= row['g_scale_factor'] + gt_expanded[2] *= row['b_scale_factor'] + aligned = apply_alignment(gt_expanded.transpose(1, 2, 0), row.to_dict())[self.buffer:-self.buffer, self.buffer:-self.buffer] + 
gt_non_aligned = gt_expanded.transpose(1, 2, 0)[self.buffer:-self.buffer, self.buffer:-self.buffer] + # # gt_non_aligned = gt_non_aligned * noisy.mean() / aligned.mean() + # # aligned = aligned * noisy.mean() / aligned.mean() + # # Get Raw data for testing + # noisy_raw = noisy_rh.raw[dims[0]:dims[1], dims[2]: dims[3]] + # row_dict = row.to_dict() + # shift_y, shift_x = round_to_nearest_2(row_dict['M12']), round_to_nearest_2(row_dict['M11']) + # gt_raw = gt_rh.raw[dims[0]+shift_y:dims[1]+shift_y, dims[2]+shift_x:dims[3]+shift_x] + # aligned = gt_rh.as_rgb(dims=dims, colorspace=self.colorspace).transpose(1, 2, 0) + + # Convert to tensors + output = { + "bayer": torch.tensor(bayer_data).to(float).clip(0,1), + "gt_non_aligned": torch.tensor(gt_non_aligned).to(float).permute(2, 0, 1).clip(0,1), + "aligned": torch.tensor(aligned).to(float).permute(2, 0, 1).clip(0,1), + "sparse": torch.tensor(sparse).to(float).clip(0,1), + "noisy": torch.tensor(noisy).to(float).clip(0,1), + "rggb": torch.tensor(rggb).to(float).clip(0,1), + "conditioning": torch.tensor([row.iso/self.coordinate_iso]).to(float), + # "noisy_raw": noisy_raw, + # "gt_raw": gt_raw, + # "noise_est": noise_est, + # "rggb_gt": rggb_gt, + } + return output + + + + +def global_affine_match(A, D, mask=None): + """ + Fit D ≈ a + b*A with least squares. 
+ A, D : 2D arrays, same shape (linear values) + mask : optional boolean array, True=use pixel + returns: a, b, D_pred, D_resid (D - (a + b*A)) + """ + A = A.ravel().astype(np.float64) + D = D.ravel().astype(np.float64) + if mask is None: + mask = np.isfinite(A) & np.isfinite(D) + else: + mask = mask.ravel() & np.isfinite(A) & np.isfinite(D) + + A0 = A[mask] + D0 = D[mask] + # design matrix [1, A] + X = np.vstack([np.ones_like(A0), A0]).T + coef, *_ = np.linalg.lstsq(X, D0, rcond=None) + a, b = coef[0], coef[1] + D_pred = (a + b * A).reshape(-1) + D_pred = D_pred.reshape(A.shape) if False else (a + b * A).reshape((-1,)) # keep flatten + + return a, b, (a + b * A) + + +def random_crop_dim(shape, crop_size, buffer, validation=False): + h, w = shape + if not validation: + top = np.random.randint(0 + buffer, h - crop_size - buffer) + left = np.random.randint(0 + buffer, w - crop_size - buffer) + else: + top = (h - crop_size) // 2 + left = (w - crop_size) // 2 + + if top % 2 != 0: top = top - 1 + if left % 2 != 0: left = left - 1 + bottom = top + crop_size + right = left + crop_size + return (left, right, top, bottom) + +def check_align_matrix(row, tolerance=1e-7): + is_identity = np.isclose(row['M00'], 1.0, atol=tolerance) and \ + np.isclose(row['M01'], 0.0, atol=tolerance) and \ + np.isclose(row['M10'], 0.0, atol=tolerance) and \ + np.isclose(row['M11'], 1.0, atol=tolerance) + + assert is_identity, "Rotations, scalings, or shearing are not tested." 
+ + +def round_to_nearest_2(number): + return round(number / 2) * 2 \ No newline at end of file diff --git a/src/training/SmallRawDataset.py b/src/training/SmallRawDataset.py new file mode 100644 index 0000000..50fd962 --- /dev/null +++ b/src/training/SmallRawDataset.py @@ -0,0 +1,93 @@ +import pandas as pd +import os +from torch.utils.data import Dataset +import imageio +from colour_demosaicing import ( + ROOT_RESOURCES_EXAMPLES, + demosaicing_CFA_Bayer_bilinear, + demosaicing_CFA_Bayer_Malvar2004, + demosaicing_CFA_Bayer_Menon2007, + mosaicing_CFA_Bayer) + +from src.training.utils import inverse_gamma_tone_curve, cfa_to_sparse +import numpy as np +import torch +from src.training.align_images import apply_alignment, align_clean_to_noisy +from pathlib import Path + +class SmallRawDataset(Dataset): + def __init__(self, path, csv, crop_size=180, buffer=10, validation=False, run_align=False): + super().__init__() + self.df = pd.read_csv(csv) + self.path = path + self.crop_size = crop_size + self.buffer = buffer + self.coordinate_iso = 6400 + self.validation=validation + self.run_align = run_align + def __len__(self): + return len(self.df) + + def __getitem__(self, idx): + row = self.df.iloc[idx] + # Load images + with imageio.imopen(self.path / Path(f"{row.bayer_path}").name, "r") as image_resource: + bayer_data = image_resource.read() + + with imageio.imopen(self.path / Path(f"{row.gt_path}").name, "r") as image_resource: + gt_image = image_resource.read() + gt_image = gt_image/255 + bayer_data = bayer_data/255 + demosaiced_noisy = demosaicing_CFA_Bayer_Malvar2004(bayer_data) + + if self.run_align: + gt_image = (gt_image * 255).astype(np.uint8) + demosaiced_noisy = (demosaiced_noisy * 255).astype(np.uint8) + aligned, _, _ = align_clean_to_noisy(gt_image, demosaiced_noisy, refine=False, verbose=False) + aligned = aligned / 255 + else: + aligned = apply_alignment(gt_image, row.to_dict()) + + h, w, _ = gt_image.shape + + #Crop images + if not self.validation: + top = 
np.random.randint(0 + self.buffer, h - self.crop_size - self.buffer) + left = np.random.randint(0 + self.buffer, w - self.crop_size - self.buffer) + else: + top = (h - self.crop_size) // 2 + left = (w - self.crop_size) // 2 + + if top % 2 != 0: top = top - 1 + if left % 2 != 0: left = left - 1 + bottom = top + self.crop_size + right = left + self.crop_size + aligned = aligned[top:bottom, left:right] + gt_image = gt_image[top:bottom, left:right] + bayer_data = bayer_data[top:bottom, left:right] + h, w, _ = gt_image.shape + + # Translate to linear + gt_image = inverse_gamma_tone_curve(gt_image) + aligned = inverse_gamma_tone_curve(aligned) + bayer_data = inverse_gamma_tone_curve(bayer_data) + + demosaiced_noisy = demosaicing_CFA_Bayer_Malvar2004(bayer_data) + + aligned = aligned * demosaiced_noisy.mean() / aligned.mean() + gt_image = gt_image * demosaiced_noisy.mean() / gt_image.mean() + + sparse, _ = cfa_to_sparse(bayer_data) + rggb =bayer_data.reshape(h // 2, 2, w // 2, 2, 1).transpose(1, 3, 4, 0, 2).reshape(4, h // 2, w // 2) + + # Convert to tensors + output = { + "bayer": torch.tensor(bayer_data).to(float).clip(0,1), + "gt": torch.tensor(gt_image).to(float).permute(2, 0, 1).clip(0,1), + "aligned": torch.tensor(aligned).to(float).permute(2, 0, 1).clip(0,1), + "sparse": torch.tensor(sparse).to(float).clip(0,1), + "noisy": torch.tensor(demosaiced_noisy).to(float).permute(2, 0, 1).clip(0,1), + "rggb": torch.tensor(rggb).to(float).clip(0,1), + "conditioning": torch.tensor([row.iso/self.coordinate_iso]).to(float), + } + return output \ No newline at end of file diff --git a/src/training/SmallRawDatasetNumpy.py b/src/training/SmallRawDatasetNumpy.py new file mode 100644 index 0000000..15dd614 --- /dev/null +++ b/src/training/SmallRawDatasetNumpy.py @@ -0,0 +1,136 @@ +import pandas as pd +import os +from torch.utils.data import Dataset +import imageio +from colour_demosaicing import ( + ROOT_RESOURCES_EXAMPLES, + demosaicing_CFA_Bayer_bilinear, + 
demosaicing_CFA_Bayer_Malvar2004, + demosaicing_CFA_Bayer_Menon2007, + mosaicing_CFA_Bayer) + +from src.training.utils import inverse_gamma_tone_curve, cfa_to_sparse +import numpy as np +import torch +from src.training.align_images import apply_alignment, align_clean_to_noisy +from pathlib import Path + + + +def global_affine_match(A, D, mask=None): + """ + Fit D ≈ a + b*A with least squares. + A, D : 2D arrays, same shape (linear values) + mask : optional boolean array, True=use pixel + returns: a, b, D_pred, D_resid (D - (a + b*A)) + """ + A = A.ravel().astype(np.float64) + D = D.ravel().astype(np.float64) + if mask is None: + mask = np.isfinite(A) & np.isfinite(D) + else: + mask = mask.ravel() & np.isfinite(A) & np.isfinite(D) + + A0 = A[mask] + D0 = D[mask] + # design matrix [1, A] + X = np.vstack([np.ones_like(A0), A0]).T + coef, *_ = np.linalg.lstsq(X, D0, rcond=None) + a, b = coef[0], coef[1] + D_pred = (a + b * A).reshape(-1) + D_pred = D_pred.reshape(A.shape) if False else (a + b * A).reshape((-1,)) # keep flatten + + return a, b, (a + b * A) + +class SmallRawDatasetNumpy(Dataset): + def __init__(self, path, csv, crop_size=180, buffer=10, validation=False, run_align=False, dimensions=2000): + super().__init__() + self.df = pd.read_csv(csv) + self.path = path + self.crop_size = crop_size + self.buffer = buffer + self.coordinate_iso = 6400 + self.validation=validation + self.run_align = run_align + self.dtype = np.float16 + self.dimensions = dimensions + def __len__(self): + return len(self.df) + + def __getitem__(self, idx): + row = self.df.iloc[idx] + # Load images + + name = Path(f"{row.bayer_path}").name + name = name.replace('_bayer.jpg', '.u16.raw') + bayer_data = np.fromfile(self.path / name, dtype=self.dtype) + bayer_data = bayer_data.reshape((self.dimensions, self.dimensions)) + + name = Path(f"{row.gt_path}").name + name = name.replace('jpg', 'u16.raw') + gt_image = np.fromfile(self.path / name, dtype=self.dtype) + gt_image = 
gt_image.reshape((self.dimensions, self.dimensions)) + + + + gt_image = gt_image + bayer_data = bayer_data + gt_image = demosaicing_CFA_Bayer_Malvar2004(gt_image) + demosaiced_noisy = demosaicing_CFA_Bayer_Malvar2004(bayer_data) + + if self.run_align: + gt_image = (gt_image * 255).astype(np.uint8) + demosaiced_noisy = (demosaiced_noisy * 255).astype(np.uint8) + aligned, _, _ = align_clean_to_noisy(gt_image, demosaiced_noisy, refine=False, verbose=False) + aligned = aligned / 255 + else: + aligned = apply_alignment(gt_image, row.to_dict()) + h, w, _ = gt_image.shape + + #Crop images + if not self.validation: + top = np.random.randint(0 + self.buffer, h - self.crop_size - self.buffer) + left = np.random.randint(0 + self.buffer, w - self.crop_size - self.buffer) + else: + top = (h - self.crop_size) // 2 + left = (w - self.crop_size) // 2 + + if top % 2 != 0: top = top - 1 + if left % 2 != 0: left = left - 1 + bottom = top + self.crop_size + right = left + self.crop_size + aligned = aligned[top:bottom, left:right] + gt_image = gt_image[top:bottom, left:right] + bayer_data = bayer_data[top:bottom, left:right] + h, w, _ = gt_image.shape + + demosaiced_noisy = demosaicing_CFA_Bayer_Malvar2004(bayer_data) + sparse, _ = cfa_to_sparse(bayer_data) + rggb = bayer_data.reshape(h // 2, 2, w // 2, 2, 1).transpose(1, 3, 4, 0, 2).reshape(4, h // 2, w // 2) + + # # Affine transform to match brightness in gt to noisy + # a, b, aligned = global_affine_match(aligned, demosaiced_noisy) + + # print(a, b, aligned.shape) + # gt_image = gt_image * demosaiced_noisy.mean() / gt_image.mean() + # aligned = aligned * demosaiced_noisy.mean() / aligned.mean() + + + # Sim noise method + mosaiced_gt = mosaicing_CFA_Bayer(aligned) + rggb_gt = mosaiced_gt.reshape(h // 2, 2, w // 2, 2, 1).transpose(1, 3, 4, 0, 2).reshape(4, h // 2, w // 2) + noise_est = rggb - rggb_gt + + # Convert to tensors + output = { + "bayer": torch.tensor(bayer_data).to(float).clip(0,1), + "gt_non_aligned": 
torch.tensor(gt_image).to(float).permute(2, 0, 1).clip(0,1), + "aligned": torch.tensor(aligned).to(float).permute(2, 0, 1).clip(0,1), + "sparse": torch.tensor(sparse).to(float).clip(0,1), + "noisy": torch.tensor(demosaiced_noisy).to(float).permute(2, 0, 1).clip(0,1), + "rggb": torch.tensor(rggb).to(float).clip(0,1), + "conditioning": torch.tensor([row.iso/self.coordinate_iso]).to(float), + "noise_est": noise_est, + "rggb_gt": rggb_gt, + } + return output \ No newline at end of file diff --git a/src/training/VGGFeatureExtractor.py b/src/training/VGGFeatureExtractor.py new file mode 100644 index 0000000..d9a01b4 --- /dev/null +++ b/src/training/VGGFeatureExtractor.py @@ -0,0 +1,81 @@ +import torch.nn as nn +import torch +import torch.nn.functional as F +import pytorch_msssim +import math + +def custom_weight_init(m): + if isinstance(m, nn.Conv2d): + k = m.kernel_size[0] # assuming square kernels + c_in = m.in_channels + n_l = (k ** 2) * c_in + std = math.sqrt(2.0 / n_l) + nn.init.normal_(m.weight, mean=0.0, std=std) + if m.bias is not None: + nn.init.constant_(m.bias, 0.0) + elif isinstance(m, nn.Linear): + n_l = m.in_features + std = math.sqrt(2.0 / n_l) + nn.init.normal_(m.weight, mean=0.0, std=std) + if m.bias is not None: + nn.init.constant_(m.bias, 0.0) + + +class VGGFeatureExtractor(nn.Module): + """ + VGG-like network for perceptual loss with runtime hot-swappable activations. + Returns features from selected layers. 
+ """ + def __init__(self, config=None, feature_layers=None, activation=None): + super().__init__() + if config is None: + config = [(2, 64), (2, 128), (3, 256), (3, 512), (3, 512)] + + if feature_layers is None: + feature_layers = [3, 8, 15, 22, 29] + + if activation is None: + activation = lambda: nn.ReLU(inplace=False) + + layers = [] + in_channels = 3 + for num_convs, out_channels in config: + for _ in range(num_convs): + layers.append(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)) + layers.append(activation()) + in_channels = out_channels + layers.append(nn.MaxPool2d(kernel_size=2, stride=2)) + + self.features = nn.Sequential(*layers) + self.feature_layers = feature_layers + self.activation_factory = activation # store for later swapping + self.apply(custom_weight_init) + + def set_activation(self, activation_cls, **kwargs): + """ + Replace all activation layers with new activation. + activation_cls: activation class (e.g., nn.GELU, nn.LeakyReLU) + kwargs: any keyword args for the activation class + """ + for i, layer in enumerate(self.features): + if isinstance(layer, nn.ReLU) or \ + isinstance(layer, nn.LeakyReLU) or \ + isinstance(layer, nn.GELU) or \ + isinstance(layer, nn.SiLU) or \ + isinstance(layer, nn.Identity) or \ + isinstance(layer, nn.Tanh): + self.features[i] = activation_cls(**kwargs) + # Update the stored factory for future reference + self.activation_factory = lambda: activation_cls(**kwargs) + + def forward(self, x): + outputs = [] + for i, layer in enumerate(self.features): + x = layer(x) + if i in self.feature_layers: + outputs.append(x) + return outputs + + + + diff --git a/src/training/align_images.py b/src/training/align_images.py new file mode 100644 index 0000000..ab0a56a --- /dev/null +++ b/src/training/align_images.py @@ -0,0 +1,179 @@ +import cv2 +import numpy as np +from skimage.metrics import structural_similarity as ssim, peak_signal_noise_ratio as psnr +import pandas as pd +import os +from torch.utils.data 
def align_clean_to_noisy(clean_img, noisy_img, refine=True, verbose=False):
    """
    Aligns the clean image to the noisy image and returns:
      - aligned image
      - best warp matrix (2x3 affine)
      - metrics dict (PSNR/SSIM before and after alignment)

    A coarse translation is estimated by phase correlation on mean/std
    normalised grayscale copies of both images.  When ``refine`` is true an
    affine ECC refinement is estimated on the translated image and composed
    with the translation into a single matrix.

    Args:
        clean_img: image to warp; 2-D, or 3-D (converted with
            ``cv2.COLOR_BGR2GRAY`` for the estimation and metric steps).
        noisy_img: reference image that defines the target geometry.
        refine: run the ECC affine refinement after the translation.
        verbose: print the ECC correlation, or the error if ECC fails.

    Returns:
        ``(aligned, M_total, metrics)``.  ``M_total`` is meant to be applied
        with ``cv2.WARP_INVERSE_MAP``.  ``metrics`` holds PSNR/SSIM before and
        after (computed with ``data_range=255``), the phase-correlation
        ``dx``/``dy``/``response`` and the matrix entries as ``m00`` .. ``m12``.
    """

    def to_gray(img):
        return cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) if img.ndim == 3 else img

    # --- convert to grayscale float32 for processing ---
    def to_gray_normalized(img):
        g = to_gray(img).astype(np.float32)
        return (g - g.mean()) / (g.std() + 1e-8)

    clean_gray = to_gray_normalized(clean_img)
    noisy_gray = to_gray_normalized(noisy_img)
    h, w = clean_gray.shape

    def warp_clean(matrix):
        # Every warp resamples the ORIGINAL clean image, so interpolation is
        # applied only once regardless of how many stages were estimated.
        return cv2.warpAffine(clean_img, matrix, (w, h),
                              flags=cv2.INTER_CUBIC + cv2.WARP_INVERSE_MAP,
                              borderMode=cv2.BORDER_REFLECT)

    # --- PHASE CORRELATION (coarse translation) ---
    shift, response = cv2.phaseCorrelate(noisy_gray, clean_gray)  # (dx, dy)
    dx, dy = shift
    M_trans = np.array([[1, 0, dx], [0, 1, dy]], dtype=np.float32)
    aligned = warp_clean(M_trans)
    M_total = M_trans

    # --- optional ECC refinement (affine) ---
    if refine:
        warp_matrix = np.eye(2, 3, dtype=np.float32)
        criteria = (cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, 2000, 1e-8)
        try:
            cc, warp_matrix = cv2.findTransformECC(
                noisy_gray,                            # template (target)
                to_gray(aligned).astype(np.float32),   # input (translated clean)
                warp_matrix,
                cv2.MOTION_AFFINE,
                criteria,
                None,
                5
            )
            if verbose:
                print(f"ECC converged: corr={cc:.5f}")
            # Both matrices are inverse maps (output pixel -> source pixel):
            #   aligned(x) = clean(M_trans @ x)   and   noisy(x) ~ aligned(M_ecc @ x)
            # hence noisy(x) ~ clean(M_trans @ M_ecc @ x).  The translation
            # must therefore be the LEFT factor; the reverse order is only
            # correct when the ECC result is itself a pure translation.
            M1 = np.vstack([M_trans, [0, 0, 1]])
            M2 = np.vstack([warp_matrix, [0, 0, 1]])
            M_total = (M1 @ M2)[:2, :]
            aligned = warp_clean(M_total)
        except cv2.error as e:
            # Best effort: keep the translation-only result.
            if verbose:
                print("ECC failed:", e)

    # --- compute metrics ---
    clean_g = to_gray(clean_img)
    noisy_g = to_gray(noisy_img)
    aligned_g = to_gray(aligned)

    metrics = {
        "PSNR_before": psnr(noisy_g, clean_g, data_range=255),
        "PSNR_after": psnr(noisy_g, aligned_g, data_range=255),
        "SSIM_before": ssim(noisy_g, clean_g, data_range=255),
        "SSIM_after": ssim(noisy_g, aligned_g, data_range=255),
        "dx": dx,
        "dy": dy,
        "response": response,
    }

    # flatten warp matrix for CSV
    for i in range(2):
        for j in range(3):
            metrics[f"m{i}{j}"] = float(M_total[i, j])

    return aligned, M_total, metrics
class AlignImages(Dataset):
    """Dataset that aligns each ground-truth image to its noisy counterpart.

    Each row of the CSV must provide ``noisy_image``, ``gt_image`` and ``iso``.
    Images are read from ``{path}/{noisy_image}_bayer.jpg`` and
    ``{path}/{gt_image}.jpg``.

    Args:
        path: directory holding the CSV and the images.
        csv: CSV file name, relative to ``path``.
        crop_size, buffer, validation: stored on the instance; not used by
            ``__getitem__``.
    """

    def __init__(self, path, csv, crop_size=180, buffer=10, validation=False):
        super().__init__()
        self.df = pd.read_csv(os.path.join(path, csv))
        self.path = path
        self.crop_size = crop_size
        self.buffer = buffer
        self.coordinate_iso = 6400
        self.validation = validation

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(gt_image, demosaiced_noisy, aligned, metrics)`` for row ``idx``.

        ``metrics`` is the dict produced by ``align_clean_to_noisy`` (run with
        ``refine=False``) extended with ``iso``, ``std`` (standard deviation of
        ``demosaiced_noisy - aligned``) and the two file names.
        """
        row = self.df.iloc[idx]

        # Load images
        bayer_path = f"{self.path}/{row.noisy_image}_bayer.jpg"
        with imageio.imopen(bayer_path, "r") as image_resource:
            bayer_data = image_resource.read()

        gt_path = f"{self.path}/{row.gt_image}.jpg"
        with imageio.imopen(gt_path, "r") as image_resource:
            gt_image = image_resource.read()

        demosaiced_noisy = demosaicing_CFA_Bayer_Malvar2004(bayer_data)
        # The Malvar filter can overshoot the input range; clip before the
        # cast so out-of-range values saturate instead of wrapping in uint8.
        demosaiced_noisy = np.clip(demosaiced_noisy, 0, 255).astype(np.uint8)

        # The alignment is estimated from the images themselves, so no
        # pre-computed warp columns are read from the CSV here.
        aligned, _, metrics = align_clean_to_noisy(gt_image, demosaiced_noisy, refine=False)
        metrics['iso'] = row.iso
        metrics['std'] = (demosaiced_noisy.astype(int) - aligned.astype(int)).std()
        metrics['bayer_path'] = Path(bayer_path).name
        metrics['gt_path'] = Path(gt_path).name
        return gt_image, demosaiced_noisy, aligned, metrics
0000000..f5f49b9 --- /dev/null +++ b/src/training/censored_fit.py @@ -0,0 +1,92 @@ +import numpy as np +from scipy.stats import norm + +def censored_linear_fit_twosided(x, y, clip_low=None, clip_high=None, + max_iter=200, tol=1e-6, include_offset=True): + """ + Fit y ≈ a + b*x + ε, ε ~ N(0, σ²) under two-sided censoring: + clip_low ≤ y_true ≤ clip_high + Observed y are clipped to [clip_low, clip_high]. + Returns (a, b, sigma) estimated via EM. + + Parameters + ---------- + x, y : array_like + Input data. + clip_low, clip_high : float or None + Lower/upper clip levels. Can be None for one-sided clipping. + max_iter : int + Maximum EM iterations. + tol : float + Relative tolerance for convergence. + include_offset : bool + If False, forces intercept a=0 (fit y ≈ b*x). + + Returns + ------- + a, b, sigma : floats + Estimated regression parameters. + """ + x = np.asarray(x).ravel() + y = np.asarray(y).ravel() + mask = np.isfinite(x) & np.isfinite(y) + x, y = x[mask], y[mask] + + if len(x) < 3: + raise ValueError("Not enough data points.") + + # --- initial guess (ordinary least squares) --- + if include_offset: + A = np.vstack([np.ones_like(x), x]).T + a, b = np.linalg.lstsq(A, y, rcond=None)[0] + else: + b = np.dot(x, y) / np.dot(x, x) + a = 0.0 + sigma = np.std(y - (a + b*x)) + + for _ in range(max_iter): + mu = a + b*x + y_exp = y.copy() + + # Handle right-censoring (high clip) + if clip_high is not None: + high_mask = y >= clip_high - 1e-12 + if np.any(high_mask): + z = (clip_high - mu[high_mask]) / sigma + Phi = norm.cdf(z) + phi = norm.pdf(z) + one_minus_Phi = 1.0 - Phi + lambda_ = np.zeros_like(z) + valid = one_minus_Phi > 1e-15 + lambda_[valid] = phi[valid] / one_minus_Phi[valid] + y_exp[high_mask] = mu[high_mask] + sigma * lambda_ + + # Handle left-censoring (low clip) + if clip_low is not None: + low_mask = y <= clip_low + 1e-12 + if np.any(low_mask): + z = (clip_low - mu[low_mask]) / sigma + Phi = norm.cdf(z) + phi = norm.pdf(z) + lambda_ = np.zeros_like(z) 
def simulate_sparse(image, pattern="RGGB", cfa_type="bayer"):
    """
    Simulate a sparse CFA (Color Filter Array) from an RGB image.

    Args:
        image: numpy array (3, H, W), RGB image in [0, 1] or [0, 255].
        pattern: CFA pattern string, one of {"RGGB","BGGR","GRBG","GBRG"} for Bayer,
                 or ignored if cfa_type="xtrans".
        cfa_type: "bayer" or "xtrans".

    Returns:
        cfa: numpy array (3, H, W); each pixel keeps only the channel selected
             by the CFA, the other two channels are zero.
        sparse_mask: numpy array (3, H, W), 1 where ``cfa`` holds a sample.

    Raises:
        ValueError: unknown Bayer ``pattern`` or unknown ``cfa_type``.
    """
    _, H, W = image.shape

    if cfa_type == "bayer":
        # 2x2 Bayer tiles, one string per tile row
        bayer_tiles = {
            "RGGB": ("RG", "GB"),
            "BGGR": ("BG", "GR"),
            "GRBG": ("GR", "BG"),
            "GBRG": ("GB", "RG"),
        }
        if pattern not in bayer_tiles:
            raise ValueError(f"Unknown Bayer pattern: {pattern}")
        tile = bayer_tiles[pattern]
    elif cfa_type == "xtrans":
        # Fuji X-Trans 6x6 repeating pattern
        tile = ("GBRGRB", "RGGBGG", "BGGRGG", "GRBGBR", "BGGRGG", "RGGBGG")
    else:
        raise ValueError(f"Unknown CFA type: {cfa_type}")

    cfa = np.zeros((3, H, W), dtype=image.dtype)
    sparse_mask = np.zeros((3, H, W), dtype=image.dtype)

    # The sampling stride must equal the tile period for BOTH outputs;
    # striding the mask by 2 on the 6x6 X-Trans tile marks most pixels in
    # several channels at once.
    period = len(tile)
    for i, tile_row in enumerate(tile):
        for j, color in enumerate(tile_row):
            ch = "RGB".index(color)
            cfa[ch, i::period, j::period] = image[ch, i::period, j::period]
            sparse_mask[ch, i::period, j::period] = 1

    return cfa, sparse_mask
def bilinear_demosaic(sparse, pattern="RGGB", cfa_type="bayer"):
    """
    Simple bilinear demosaicing for Bayer and X-Trans CFAs.

    Args:
        sparse: numpy array (3, H, W) with a sparse representation of the bayer image.
        pattern: CFA pattern string, one of {"RGGB","BGGR","GRBG","GBRG"} for Bayer.
                 Accepted for signature symmetry; the kernels do not depend on it.
        cfa_type: "bayer" or "xtrans".

    Returns:
        rgb: numpy array (3, H, W), channel-first like the input, holding the
             interpolated image.

    Raises:
        ValueError: unknown ``cfa_type``.
    """
    C, H, W = sparse.shape

    if cfa_type == "bayer":
        # Bilinear interpolation kernels: red and blue are sampled on every
        # second row and column, green on a quincunx.
        red_blue = np.array([[.25, .5, .25], [.5, 1, .5], [.25, .5, .25]], dtype=np.float32)
        green = np.array([[0, .25, 0], [.25, 1, .25], [0, .25, 0]], dtype=np.float32)
        kernels = [red_blue, green, red_blue]
    elif cfa_type == "xtrans":
        kernel = np.array([
            [1, 2, 4, 2, 1],
            [2, 4, 8, 4, 2],
            [4, 8, 16, 8, 4],
            [2, 4, 8, 4, 2],
            [1, 2, 4, 2, 1]], dtype=np.float32)
        kernel /= kernel.sum()
        # One shared kernel for all three channels.  Without this list the
        # loop below would hit an undefined ``kernels`` for X-Trans input.
        # NOTE(review): the unit-sum kernel does not compensate for channel
        # sparsity, so the output is darker than the samples - confirm intent.
        kernels = [kernel, kernel, kernel]
    else:
        raise ValueError(f"Unknown CFA type: {cfa_type}")

    rgb = np.zeros((C, H, W), dtype=sparse.dtype)

    # Interpolate each channel
    for ch in range(3):
        rgb[ch, ...] = convolve(sparse[ch], kernels[ch], mode="mirror")

    return rgb
def cfa_to_sparse(image, pattern="RGGB", cfa_type="bayer"):
    """
    Make a sparse representation from a CFA

    Args:
        image: numpy array (H, W), single-plane CFA mosaic.
        pattern: CFA pattern string, one of {"RGGB","BGGR","GRBG","GBRG"} for Bayer,
                 or ignored if cfa_type="xtrans".
        cfa_type: "bayer" or "xtrans".

    Returns:
        cfa: numpy array (3, H, W); every mosaic sample is copied into the
             channel its CFA position belongs to, the other channels are zero.
        sparse_mask: numpy array (3, H, W), 1 where ``cfa`` holds a sample.

    Raises:
        ValueError: unknown Bayer ``pattern`` or unknown ``cfa_type``.
    """
    H, W = image.shape

    if cfa_type == "bayer":
        # 2x2 Bayer tiles, one string per tile row
        bayer_tiles = {
            "RGGB": ("RG", "GB"),
            "BGGR": ("BG", "GR"),
            "GRBG": ("GR", "BG"),
            "GBRG": ("GB", "RG"),
        }
        if pattern not in bayer_tiles:
            raise ValueError(f"Unknown Bayer pattern: {pattern}")
        tile = bayer_tiles[pattern]
    elif cfa_type == "xtrans":
        # Fuji X-Trans 6x6 repeating pattern
        tile = ("GBRGRB", "RGGBGG", "BGGRGG", "GRBGBR", "BGGRGG", "RGGBGG")
    else:
        raise ValueError(f"Unknown CFA type: {cfa_type}")

    cfa = np.zeros((3, H, W), dtype=image.dtype)
    sparse_mask = np.zeros((3, H, W), dtype=image.dtype)

    # Mask and samples share the tile period as stride; a stride of 2 on the
    # 6x6 X-Trans tile would flag pixels in channels that hold no sample.
    period = len(tile)
    for i, tile_row in enumerate(tile):
        for j, color in enumerate(tile_row):
            ch = "RGB".index(color)
            cfa[ch, i::period, j::period] = image[i::period, j::period]
            sparse_mask[ch, i::period, j::period] = 1

    return cfa, sparse_mask
class ShadowAwareLoss(nn.Module):
    def __init__(self,
                 alpha=0.2,
                 beta=5.0,
                 l1_weight=0.16,
                 ssim_weight=0.84,
                 tv_weight=0.0,
                 vgg_loss_weight=0.0,
                 apply_gamma_fn=None,
                 vgg_feature_extractor=None,
                 percept_loss_weight = 0,
                 device=None,
                 sharpness_loss_weight=0):
        """
        Shadow-aware image restoration loss.

        Args:
            alpha: Minimum tone weight for bright pixels.
            beta: Controls how quickly weight increases in shadows.
            l1_weight, ssim_weight, tv_weight, vgg_loss_weight: Loss scaling factors.
            apply_gamma_fn: Optional function to apply gamma correction to input tensors.
            vgg_feature_extractor: Optional VGG feature extractor returning feature maps.
            percept_loss_weight: Scaling factor for the ``VGGPerceptualLoss`` term.
            device: Optional device to move inputs and buffers to.
            sharpness_loss_weight: Scaling factor for the ``sharpness_loss`` term.
        """
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.l1_weight = l1_weight
        self.ssim_weight = ssim_weight
        self.tv_weight = tv_weight
        self.vgg_loss_weight = vgg_loss_weight
        self.apply_gamma_fn = apply_gamma_fn
        self.vfe = vgg_feature_extractor
        self.device = device
        self.percept_loss_weight = percept_loss_weight
        self.VGGPerceptualLoss = VGGPerceptualLoss()
        self.sharpness_loss_weight = sharpness_loss_weight

        if device is not None:
            self.to(device)

    def forward(self, pred, target):
        """
        Args:
            pred:   [B, C, H, W] restored image in [0,1]
            target: [B, C, H, W] ground truth in [0,1]

        Returns:
            Scalar tensor: the weighted sum of the enabled loss terms.
        """
        if self.apply_gamma_fn is not None:
            # Clamp before and after so the gamma curve never sees values
            # outside (0, 1].
            pred = torch.clamp(pred, 1e-6, 1.0)
            target = torch.clamp(target, 1e-6, 1.0)
            pred = self.apply_gamma_fn(pred).clamp(1e-6, 1)
            target = self.apply_gamma_fn(target).clamp(1e-6, 1)

        # Convert to luminance (BT.709)
        lum = 0.2126 * target[:, 0] + 0.7152 * target[:, 1] + 0.0722 * target[:, 2]  # [B, H, W]
        tone_weight = self.alpha + (1.0 - self.alpha) * torch.exp(-self.beta * lum)
        tone_weight = tone_weight.unsqueeze(1)  # [B, 1, H, W]

        # Weighted L1 loss
        l1 = (tone_weight * torch.abs(pred - target)).mean()

        # MS-SSIM loss (not tone-weighted)
        ssim = 1 - ms_ssim(pred, target, data_range=1.0, size_average=True)

        # TV loss only in low-light areas
        shadow_mask = (lum < 0.2).float().unsqueeze(1)
        dx = torch.abs(pred[:, :, :, 1:] - pred[:, :, :, :-1])
        dy = torch.abs(pred[:, :, 1:, :] - pred[:, :, :-1, :])
        tv = ((shadow_mask[:, :, :, 1:] * dx).mean() +
              (shadow_mask[:, :, 1:, :] * dy).mean())

        # Optional VGG perceptual loss
        vgg_loss_val = 0
        if self.vgg_loss_weight != 0 and self.vfe is not None:
            # The prediction must stay in the autograd graph, otherwise this
            # term is a constant and never influences the model.  Only the
            # target branch, which needs no gradient, runs under no_grad.
            pred_features = self.vfe(pred)
            with torch.no_grad():
                target_features = self.vfe(target)
            vgg_loss_val = nn.functional.mse_loss(pred_features[0], target_features[0])

        percept_loss = 0
        if self.percept_loss_weight:
            percept_loss = self.VGGPerceptualLoss(pred, target)

        sharpness_loss_value = 0
        if self.sharpness_loss_weight:
            sharpness_loss_value += sharpness_loss(pred, target)

        # Combine weighted terms
        total_loss = (
            self.l1_weight * l1
            + self.ssim_weight * ssim
            + self.tv_weight * tv
            + self.vgg_loss_weight * vgg_loss_val
            + self.percept_loss_weight * percept_loss
            + self.sharpness_loss_weight * sharpness_loss_value
        )

        return total_loss


def sharpness_loss(pred, target, loss_func = torch.nn.functional.l1_loss):
    """Return ``loss_func`` on the images plus ``loss_func`` on their
    high-pass residuals (image minus a 5x5, sigma=1 Gaussian blur)."""
    loss = loss_func(pred, target)
    pred_prime = torchvision.transforms.functional.gaussian_blur(pred, kernel_size=[5, 5], sigma=[1.0, 1.0])
    target_prime = torchvision.transforms.functional.gaussian_blur(target, kernel_size=[5, 5], sigma=[1.0, 1.0])
    loss += loss_func(pred-pred_prime, target-target_prime)
    return loss
output['conditioning'].float().to(_device) + gt = output['aligned'].float().to(_device) + rggb = output['rggb'].float().to(_device) + + conditioning = make_conditioning(conditioning, _device) + with torch.autocast(device_type="mps", dtype=torch.bfloat16): + pred = _model(rggb, conditioning, noisy) + + loss = _loss_func(pred, gt) + _optimizer.zero_grad(set_to_none=True) + loss.backward() + torch.nn.utils.clip_grad_norm_(_model.parameters(), _clipping) + _optimizer.step() + + total_loss += float(loss.detach().cpu()) + n_images += gt.shape[0] + + # Testing final image quality + final_image_loss = nn.functional.l1_loss(pred, gt) + total_final_image_loss += final_image_loss.item() + del loss, pred, final_image_loss + torch.mps.empty_cache() + + if (batch_idx + 1) % log_interval == 0: + pbar.set_postfix({"loss": f"{total_loss/n_images:.4f}"}) + + time.sleep(sleep) + + torch.save(_model.state_dict(), _outname) + + print(f"[Epoch {epoch}] " + f"Train loss: {total_loss/n_images:.6f} " + f"Final image val loss: {total_final_image_loss/n_images:.6f} " + f"Time: {perf_counter()-start:.1f}s " + f"Images: {n_images}") + + return total_loss / max(1, n_images), perf_counter()-start + + + + +def visualize(idxs, _model, dataset, _device, _loss_func): + import matplotlib.pyplot as plt + _model.train() + total_loss, n_images, total_final_image_loss = 0.0, 0, 0.0 + start = perf_counter() + + + for idx in idxs: + # for output in train_loader: + row = dataset[idx] + noisy = row['noisy'].unsqueeze(0).float().to(_device) + conditioning = row['conditioning'].float().unsqueeze(0).to(_device) + gt = row['aligned'].unsqueeze(0).float().to(_device) + rggb = row['rggb'].unsqueeze(0).float().to(_device) + + conditioning = make_conditioning(conditioning, _device) + with torch.no_grad(): + with torch.autocast(device_type="mps", dtype=torch.bfloat16): + pred = _model(rggb, conditioning, noisy) + loss = _loss_func(pred, gt) + + + total_loss += float(loss.detach().cpu()) + n_images += gt.shape[0] + + 
def make_conditioning(conditioning, device):
    """Return a ``(B, 1)`` tensor on ``device`` holding column 0 of ``conditioning``."""
    B = conditioning.shape[0]
    conditioning_extended = torch.zeros(B, 1).to(device)
    conditioning_extended[:, 0] = conditioning[:, 0]
    return conditioning_extended


def train_one_epoch(epoch, _model, _optimizer, _loader, _device, _loss_func, _clipping,
                    log_interval = 10, sleep=0.0, rggb=False,
                    max_batches=0):
    """Train ``_model`` for one pass over ``_loader`` and log metrics to MLflow.

    Args:
        epoch: epoch number, used for the progress bar and as the MLflow step.
        _model: called as ``_model(input, conditioning, noisy)``.
        _optimizer: optimizer stepped once per batch.
        _loader: iterable of dict batches with keys ``noisy``, ``conditioning``,
            ``aligned`` and ``sparse`` (or ``rggb`` when ``rggb`` is true).
        _device: device the batch tensors are moved to.
        _loss_func: called as ``_loss_func(pred, gt)``.
        _clipping: max norm passed to ``clip_grad_norm_``.
        log_interval: update the progress-bar postfix every this many batches.
        sleep: seconds to sleep after each batch.
        rggb: feed ``batch['rggb']`` and reshape the conditioning with
            ``make_conditioning``.
        max_batches: stop after this many batches; ``0`` means no limit.

    Returns:
        ``(train_loss, elapsed_seconds)``.
    """
    _model.train()
    total_loss, n_images, total_l1_loss = 0.0, 0, 0.0
    start = perf_counter()
    pbar = tqdm(enumerate(_loader), total=len(_loader), desc=f"Train Epoch {epoch}")
    input_key = 'rggb' if rggb else 'sparse'

    for batch_idx, output in pbar:
        noisy = output['noisy'].float().to(_device)
        conditioning = output['conditioning'].float().to(_device)
        gt = output['aligned'].float().to(_device)
        model_input = output[input_key].float().to(_device)
        if rggb:
            conditioning = make_conditioning(conditioning, _device)

        _optimizer.zero_grad(set_to_none=True)
        pred = _model(model_input, conditioning, noisy)

        loss = _loss_func(pred, gt)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(_model.parameters(), _clipping)
        _optimizer.step()

        # NOTE(review): per-batch losses are summed but later divided by the
        # number of IMAGES, so the reported value scales with batch size.
        # Left unchanged to keep logged metrics comparable across runs.
        total_loss += float(loss.detach().cpu())
        n_images += gt.shape[0]

        # Testing final image quality
        total_l1_loss += float(nn.functional.l1_loss(pred, gt).detach().cpu())
        del loss, pred
        # Only touch the MPS allocator when an MPS backend is present, so the
        # loop also runs on CPU/CUDA machines.
        if torch.backends.mps.is_available():
            torch.mps.empty_cache()

        if (batch_idx + 1) % log_interval == 0:
            pbar.set_postfix({"loss": f"{total_loss/n_images:.4f}"})

        # ``>=`` so exactly ``max_batches`` batches are processed.
        if max_batches > 0 and batch_idx + 1 >= max_batches:
            break
        time.sleep(sleep)

    train_time = perf_counter()-start
    denom = max(1, n_images)  # guards the summary against an empty loader
    print(f"[Epoch {epoch}] "
            f"Train loss: {total_loss/denom:.6f} "
            f"L1 loss: {total_l1_loss/denom:.6f} "
            f"Time: {train_time:.1f}s "
            f"Images: {n_images}")
    mlflow.log_metric("train_loss", total_loss/denom, step=epoch)
    mlflow.log_metric("l1_loss", total_l1_loss/denom, step=epoch)
    mlflow.log_metric("epoch_duration_s", train_time, step=epoch)
    mlflow.log_metric("learning_rate", _optimizer.param_groups[0]['lr'], step=epoch)

    return total_loss / denom, perf_counter()-start
def simulate_sparse(image, pattern="RGGB", cfa_type="bayer"):
    """
    Simulate a sparse CFA (Color Filter Array) from an RGB image.

    Args:
        image: numpy array (3, H, W), RGB image in [0, 1] or [0, 255].
        pattern: CFA pattern string, one of {"RGGB","BGGR","GRBG","GBRG"} for Bayer,
                 or ignored if cfa_type="xtrans".
        cfa_type: "bayer" or "xtrans".

    Returns:
        cfa: numpy array (3, H, W); each pixel keeps only the channel selected
             by the CFA, the other two channels are zero.
        sparse_mask: numpy array (3, H, W), 1 where ``cfa`` holds a sample.

    Raises:
        ValueError: unknown Bayer ``pattern`` or unknown ``cfa_type``.
    """
    _, H, W = image.shape

    if cfa_type == "bayer":
        # 2x2 Bayer tiles, one string per tile row
        bayer_tiles = {
            "RGGB": ("RG", "GB"),
            "BGGR": ("BG", "GR"),
            "GRBG": ("GR", "BG"),
            "GBRG": ("GB", "RG"),
        }
        if pattern not in bayer_tiles:
            raise ValueError(f"Unknown Bayer pattern: {pattern}")
        tile = bayer_tiles[pattern]
    elif cfa_type == "xtrans":
        # Fuji X-Trans 6x6 repeating pattern
        tile = ("GBRGRB", "RGGBGG", "BGGRGG", "GRBGBR", "BGGRGG", "RGGBGG")
    else:
        raise ValueError(f"Unknown CFA type: {cfa_type}")

    cfa = np.zeros((3, H, W), dtype=image.dtype)
    sparse_mask = np.zeros((3, H, W), dtype=image.dtype)

    # The sampling stride must equal the tile period for BOTH outputs;
    # striding the mask by 2 on the 6x6 X-Trans tile marks most pixels in
    # several channels at once.
    period = len(tile)
    for i, tile_row in enumerate(tile):
        for j, color in enumerate(tile_row):
            ch = "RGB".index(color)
            cfa[ch, i::period, j::period] = image[ch, i::period, j::period]
            sparse_mask[ch, i::period, j::period] = 1

    return cfa, sparse_mask
def bilinear_demosaic(sparse, pattern="RGGB", cfa_type="bayer"):
    """
    Simple bilinear demosaicing for Bayer and X-Trans CFAs.

    Each of the first three channels is convolved independently with a fixed
    interpolation kernel, using mirrored borders.

    Args:
        sparse: numpy array (3, H, W) with a sparse representation of the CFA
            image (unsampled entries are zero).
        pattern: CFA pattern string, one of {"RGGB","BGGR","GRBG","GBRG"} for
            Bayer. Accepted for signature parity; the kernels used here do not
            depend on it.
        cfa_type: "bayer" or "xtrans".

    Returns:
        rgb: numpy array (3, H, W) with the dtype of ``sparse``, the
            bilinearly demosaiced image.

    Raises:
        ValueError: if ``cfa_type`` is neither "bayer" nor "xtrans".
    """
    C, H, W = sparse.shape
    rgb = np.zeros((C, H, W), dtype=sparse.dtype)

    if cfa_type == "bayer":
        # Kernel weights sum to 4 for red/blue and 2 for green, which makes up
        # for those channels being sampled at 1/4 and 1/2 of the pixels.
        red_blue = np.array(
            [[.25, .5, .25], [.5, 1, .5], [.25, .5, .25]], dtype=np.float32)
        green = np.array(
            [[0, .25, 0], [.25, 1, .25], [0, .25, 0]], dtype=np.float32)
        kernels = [red_blue, green, red_blue]
    elif cfa_type == "xtrans":
        kernel = np.array([
            [1, 2, 4, 2, 1],
            [2, 4, 8, 4, 2],
            [4, 8, 16, 8, 4],
            [2, 4, 8, 4, 2],
            [1, 2, 4, 2, 1]], dtype=np.float32)
        # NOTE(review): this kernel is normalised to unit sum, so unlike the
        # Bayer kernels it does not compensate for the sampling density.
        kernel /= kernel.sum()
        # The same smoothing kernel is used for all three channels.
        kernels = [kernel, kernel, kernel]
    else:
        raise ValueError(f"Unknown CFA type: {cfa_type}")

    # Interpolate each channel
    for ch in range(3):
        rgb[ch, ...] = convolve(sparse[ch], kernels[ch], mode="mirror")

    return rgb
def bilinear_demosaic_torch(sparse, pattern="RGGB", cfa_type="bayer"):
    """
    Simple bilinear demosaicing for Bayer and X-Trans CFAs (Torch version).

    Each of the first three channels is filtered independently with a fixed
    interpolation kernel after reflect-padding, so the output keeps the
    spatial size of the input.

    Args:
        sparse: torch tensor (3, H, W) with a sparse representation of the CFA image.
        pattern: CFA pattern string, one of {"RGGB","BGGR","GRBG","GBRG"} for
            Bayer. Accepted for signature parity; the kernels used here do not
            depend on it.
        cfa_type: "bayer" or "xtrans".

    Returns:
        rgb: torch tensor (3, H, W) on the device and with the dtype of
            ``sparse``, the bilinearly demosaiced image.

    Raises:
        ValueError: if ``cfa_type`` is neither "bayer" nor "xtrans".
    """
    # Imported here because this module does not bind `F` at module level and
    # only imports `torch` further down the file.
    import torch
    import torch.nn.functional as F

    C, H, W = sparse.shape
    device = sparse.device
    dtype = sparse.dtype

    if cfa_type == "bayer":
        red_blue = torch.tensor([[.25, .5, .25],
                                 [.5, 1, .5],
                                 [.25, .5, .25]], dtype=dtype, device=device)
        green = torch.tensor([[0, .25, 0],
                              [.25, 1, .25],
                              [0, .25, 0]], dtype=dtype, device=device)
        kernels = [red_blue, green, red_blue]
    elif cfa_type == "xtrans":
        kernel = torch.tensor([
            [1, 2, 4, 2, 1],
            [2, 4, 8, 4, 2],
            [4, 8, 16, 8, 4],
            [2, 4, 8, 4, 2],
            [1, 2, 4, 2, 1]], dtype=dtype, device=device)
        kernel /= kernel.sum()
        kernels = [kernel, kernel, kernel]
    else:
        raise ValueError(f"Unknown CFA type: {cfa_type}")

    rgb = torch.zeros_like(sparse)

    # Filter one channel at a time. conv2d computes a cross-correlation, which
    # equals a convolution here because every kernel is symmetric.
    for ch in range(3):
        weight = kernels[ch].unsqueeze(0).unsqueeze(0)   # (1, 1, kh, kw)
        plane = sparse[ch:ch + 1].unsqueeze(0)           # (1, 1, H, W)
        pad_h, pad_w = weight.shape[-2] // 2, weight.shape[-1] // 2
        padded = F.pad(plane, (pad_w, pad_w, pad_h, pad_h), mode="reflect")
        rgb[ch] = F.conv2d(padded, weight)[0, 0]

    return rgb
def cfa_to_sparse(image, pattern="RGGB", cfa_type="bayer"):
    """
    Make a sparse three-channel representation from a single-plane CFA mosaic.

    Each mosaic pixel is copied into the channel selected by the CFA tile at
    that position; the other two channels are left at zero.

    Args:
        image: numpy array (H, W), CFA mosaic with one sample per pixel.
        pattern: Bayer layout, one of {"RGGB", "BGGR", "GRBG", "GBRG"}.
            Ignored when cfa_type="xtrans".
        cfa_type: "bayer" or "xtrans".

    Returns:
        cfa: numpy array (3, H, W), the mosaic scattered into its channels.
        sparse_mask: numpy array (3, H, W) with the dtype of ``image``;
            1 where the channel holds a sample, 0 elsewhere.

    Raises:
        ValueError: if ``cfa_type`` is unknown, or ``pattern`` is not a known
            Bayer layout when cfa_type="bayer".
    """
    H, W = image.shape
    cfa = np.zeros((3, H, W), dtype=image.dtype)
    sparse_mask = np.zeros((3, H, W), dtype=image.dtype)

    if cfa_type == "bayer":
        # 2x2 Bayer tiles, one string per tile row.
        bayer_tiles = {
            "RGGB": ("RG", "GB"),
            "BGGR": ("BG", "GR"),
            "GRBG": ("GR", "BG"),
            "GBRG": ("GB", "RG"),
        }
        if pattern not in bayer_tiles:
            raise ValueError(f"Unknown Bayer pattern: {pattern}")
        tile = bayer_tiles[pattern]
    elif cfa_type == "xtrans":
        # Fuji X-Trans 6x6 repeating tile.
        tile = (
            "GBRGRB",
            "RGGBGG",
            "BGGRGG",
            "GRBGBR",
            "BGGRGG",
            "RGGBGG",
        )
    else:
        raise ValueError(f"Unknown CFA type: {cfa_type}")

    channel_of = {"R": 0, "G": 1, "B": 2}
    # The tile is square; its side is the stride at which it repeats. Data and
    # mask must use the same stride so the mask marks exactly the filled entries.
    period = len(tile)
    for i, tile_row in enumerate(tile):
        for j, color in enumerate(tile_row):
            ch = channel_of[color]
            cfa[ch, i::period, j::period] = image[i::period, j::period]
            sparse_mask[ch, i::period, j::period] = 1

    return cfa, sparse_mask
import torch


def apply_gamma_torch(img: torch.Tensor, gamma: float = 2.2) -> torch.Tensor:
    """
    Gamma-encode a torch image: clamp to [0, 1], then raise to 1 / gamma.

    Args:
        img: tensor of any shape; values outside [0, 1] are clamped first.
        gamma: encoding gamma; the exponent applied is ``1.0 / gamma``.

    Returns:
        A new tensor with the shape of ``img`` and values in [0, 1].
    """
    img = img.clamp(0, 1)
    return (img ** (1.0 / gamma)).clamp(0, 1)


def apply_gamma(img: np.ndarray, gamma: float = 2.2) -> np.ndarray:
    """
    Gamma-encode a NumPy image: clip to [0, 1], then raise to 1 / gamma.

    Args:
        img: array of any shape; values outside [0, 1] are clipped first.
        gamma: encoding gamma; the exponent applied is ``1.0 / gamma``.

    Returns:
        A new array with the shape of ``img``.
    """
    img = np.clip(img, 0, 1)
    return np.power(img, 1.0 / gamma)


def inverse_gamma_tone_curve(img: np.ndarray, gamma: float = 2.2) -> np.ndarray:
    """
    Undo ``apply_gamma``: clip to [0, 1], then raise to ``gamma``.

    Args:
        img: array of any shape; values outside [0, 1] are clipped first.
        gamma: the gamma that was used for encoding.

    Returns:
        A new array with the shape of ``img``.
    """
    img = np.clip(img, 0, 1)
    return np.power(img, gamma)