SALAD-Pan: Sensor-Agnostic Latent Adaptive Diffusion for Pan-Sharpening

Junjie Li · Congyang Ou · Haokui Zhang · Guoting Wei · Shengqin Jiang · Ying Li · Chunhua Shen

arXiv Hugging Face Spaces Project Website Python Version PyTorch Version

Structure
Given a PAN–LRMS image pair, SALAD-Pan fine-tunes a pre-trained diffusion model to generate an HRMS image.

News

  • [02/01/2026] Code will be released soon!


Setup

Requirements

git clone https://github.com/JJLibra/SALAD-Pan.git
cd SALAD-Pan

conda create -n saladpan python=3.10 -y
conda activate saladpan

# This project depends on a modified local version of `diffusers` under `./diffusers`.
cd diffusers
pip install -e .

cd ..
pip install -r requirements.txt

Initialize an 🤗 Accelerate environment with:

accelerate config

Or use a default Accelerate configuration without answering environment questions:

accelerate config default
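For non-interactive setups you can also write the config file yourself. A minimal multi-GPU fp16 sketch with illustrative values (not necessarily identical to the repository's configs/accelerate.yaml):

```yaml
compute_environment: LOCAL_MACHINE
distributed_type: MULTI_GPU
mixed_precision: fp16
num_machines: 1
num_processes: 8   # one process per GPU
gpu_ids: all
```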

Weights

We provide two-stage checkpoints:

  • Stage I (Band-VAE): checkpoints/vae.safetensors (download from Hugging Face)

  • Stage II (Latent Diffusion): runs on top of Stable Diffusion in the Band-VAE latent space.

Repository Layout

  • configs/: training and inference YAML configurations.
  • core/: project model components and diffusion pipeline implementation.
  • utils/: data prep and metric utilities.
  • scripts/: convenience shell launchers.
  • data/: dataset root and expected structure notes.
  • checkpoints/: local checkpoint storage.
  • base/stable-diffusion-v1-5/: local SD v1.5 base model files.

Usage

Training

We train the model in two stages.

  • Stage I (VAE pretraining)
accelerate launch --config_file configs/accelerate.yaml train_vae.py --config configs/train_vae.yaml
  • Stage II (Diffusion + Adapter training)
accelerate launch --config_file configs/accelerate.yaml train_diffusion.py --config configs/train_diffusion.yaml

Note: Training usually takes 40k–50k steps, which is about 1–2 days on eight RTX 4090 GPUs in fp16. Reduce batch_size if your GPU memory is limited.
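If you do lower batch_size, raising gradient accumulation proportionally keeps the effective optimizer-step batch unchanged, so the 40k–50k-step schedule stays comparable. A small sanity-check helper, assuming the usual Accelerate semantics (the parameter names here are illustrative, not the project's config keys):

```python
def effective_batch(per_gpu_batch: int, num_gpus: int, grad_accum_steps: int) -> int:
    """Effective batch size per optimizer step under data parallelism
    with gradient accumulation: per-GPU batch x GPUs x accumulation steps."""
    return per_gpu_batch * num_gpus * grad_accum_steps

# Halving the per-GPU batch while doubling accumulation is equivalent:
assert effective_batch(8, 8, 1) == effective_batch(4, 8, 2) == 64
```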

Inference

Once training is finished, run inference:

import torch
from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel, UniPCMultistepScheduler
from transformers import AutoTokenizer, CLIPTextModel
from core.components.salad_pan import DualBranchXSAdapter, UNetDualBranchXSModel
from core.pipelines.salad_pan import StableDiffusionDualBranchXSPipeline

device = "cuda"
dtype = torch.float16

base_model = "base/stable-diffusion-v1-5"
vae_path = "output/vae_c1_gf2_qb_wv3"
adapter_weights = "output/diffusion_model/dual_branch_xs_adapter.pt"

tokenizer = AutoTokenizer.from_pretrained(base_model, subfolder="tokenizer", use_fast=False, local_files_only=True)
text_encoder = CLIPTextModel.from_pretrained(base_model, subfolder="text_encoder", local_files_only=True)
vae = AutoencoderKL.from_pretrained(vae_path, local_files_only=True)
base_unet = UNet2DConditionModel.from_pretrained(base_model, subfolder="unet", local_files_only=True)

adapter = DualBranchXSAdapter.from_unet(base_unet, size_ratio=0.25, conditioning_channels=5, conditioning_channel_order="rgb")
unet = UNetDualBranchXSModel.from_unet(base_unet, adapter=adapter)
unet.load_state_dict(torch.load(adapter_weights, map_location="cpu"), strict=False)

scheduler = UniPCMultistepScheduler.from_config(
    DDPMScheduler.from_pretrained(base_model, subfolder="scheduler", local_files_only=True).config
)

pipe = StableDiffusionDualBranchXSPipeline(
    vae=vae,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    unet=unet,
    adapter=None,
    scheduler=scheduler,
    safety_checker=None,
    feature_extractor=None,
    requires_safety_checker=False,
).to(device=device, dtype=dtype)  # cast the whole pipeline to fp16 to match the conditioning tensor

prompt = ["GaoFen-2 satellite Band Red 630-690nm"]
conditioning = torch.randn(1, 5, 128, 128, device=device, dtype=dtype)  # replace with real [LR*4, PAN]

image = pipe(
    prompt=prompt,
    image=conditioning,
    num_inference_steps=20,
    guidance_scale=1.0,
    conditioning_scale_spa=1.0,
    conditioning_scale_spe=1.0,
    output_type="pil",
).images[0]

image.save("salad_pan_demo.png")
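The random `conditioning` tensor above is a placeholder. Assuming the five channels are the 4-band LR-MS upsampled ×4 stacked with the PAN band (our reading of the `[LR*4, PAN]` comment; the exact ordering and normalization may differ in the released code), the pre-processing could be sketched as:

```python
import torch
import torch.nn.functional as F

def build_conditioning(lrms: torch.Tensor, pan: torch.Tensor) -> torch.Tensor:
    """lrms: (B, 4, h, w) multispectral patch; pan: (B, 1, 4h, 4w) panchromatic
    patch; both assumed normalized to [0, 1]. Returns (B, 5, 4h, 4w)."""
    lrms_up = F.interpolate(lrms, scale_factor=4, mode="bicubic", align_corners=False)
    lrms_up = lrms_up.clamp(0.0, 1.0)  # bicubic interpolation can overshoot [0, 1]
    return torch.cat([lrms_up, pan], dim=1)
```

Move and cast the result with `.to(device, dtype)` before passing it as `image=` to the pipeline.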

Installing xformers is highly recommended for better GPU efficiency and speed. To enable it, call pipe.enable_xformers_memory_efficient_attention() after creating the pipeline.

Results

🚨 We strongly recommend visiting our project website for a better reading experience.

Quantitative Results

Table 1. Quantitative results on the WorldView-3 (WV3) dataset. Best results are in bold.

| Models | Pub/Year | Q8 ↑ | SAM ↓ | ERGAS ↓ | SCC ↑ | Dλ ↓ | Ds ↓ | HQNR ↑ |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| PaNNet | ICCV’17 | 0.891±0.045 | 3.613±0.787 | 2.664±0.347 | 0.943±0.018 | 0.017±0.008 | 0.047±0.014 | 0.937±0.015 |
| FusionNet | TGRS’20 | 0.904±0.092 | 3.324±0.411 | 2.465±0.603 | 0.958±0.023 | 0.024±0.011 | 0.036±0.016 | 0.940±0.019 |
| LAGConv | AAAI’22 | 0.910±0.114 | 3.104±1.119 | 2.300±0.911 | 0.980±0.043 | 0.036±0.009 | 0.032±0.016 | 0.934±0.011 |
| BiMPAN | ACMM’23 | 0.915±0.087 | 2.984±0.601 | 2.257±0.552 | 0.984±0.005 | 0.017±0.019 | 0.035±0.015 | 0.949±0.026 |
| ARConv | CVPR’25 | 0.916±0.083 | 2.858±0.590 | 2.117±0.528 | **0.989±0.014** | 0.014±0.006 | 0.030±0.007 | 0.958±0.010 |
| WFANET | AAAI’25 | 0.917±0.088 | 2.855±0.618 | 2.095±0.422 | **0.989±0.011** | 0.012±0.007 | 0.031±0.009 | 0.957±0.010 |
| PanDiff | TGRS’23 | 0.898±0.090 | 3.297±0.235 | 2.467±0.166 | 0.980±0.019 | 0.027±0.108 | 0.054±0.047 | 0.920±0.077 |
| SSDiff | NeurIPS’24 | 0.915±0.086 | 2.843±0.529 | 2.106±0.416 | 0.986±0.004 | 0.013±0.005 | 0.031±0.003 | 0.956±0.016 |
| SGDiff | CVPR’25 | 0.921±0.082 | 2.771±0.511 | 2.044±0.449 | 0.987±0.009 | 0.012±0.005 | 0.027±0.003 | 0.960±0.006 |
| SALAD-Pan | Ours | **0.924±0.064** | **2.689±0.135** | **1.839±0.211** | **0.989±0.007** | **0.010±0.008** | **0.021±0.004** | **0.965±0.007** |

Table 2. Quantitative results on the QuickBird (QB) dataset. Best results are in bold.

| Models | Pub/Year | Q4 ↑ | SAM ↓ | ERGAS ↓ | SCC ↑ | Dλ ↓ | Ds ↓ | HQNR ↑ |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| PaNNet | ICCV’17 | 0.885±0.118 | 5.791±0.995 | 5.863±0.413 | 0.948±0.021 | 0.059±0.017 | 0.061±0.010 | 0.883±0.025 |
| FusionNet | TGRS’20 | 0.925±0.087 | 4.923±0.812 | 4.159±0.351 | 0.956±0.018 | 0.059±0.019 | 0.052±0.009 | 0.892±0.022 |
| LAGConv | AAAI’22 | 0.916±0.130 | 4.370±0.720 | 3.740±0.290 | 0.959±0.047 | 0.085±0.024 | 0.068±0.014 | 0.853±0.018 |
| BiMPAN | ACMM’23 | 0.931±0.091 | 4.586±0.821 | 3.840±0.319 | 0.980±0.008 | 0.026±0.020 | 0.040±0.013 | 0.935±0.030 |
| ARConv | CVPR’25 | 0.936±0.088 | 4.453±0.499 | 3.649±0.401 | **0.987±0.009** | 0.019±0.014 | 0.034±0.017 | 0.948±0.042 |
| WFANET | AAAI’25 | 0.935±0.092 | 4.490±0.582 | 3.604±0.337 | 0.986±0.008 | 0.019±0.016 | 0.033±0.019 | 0.948±0.037 |
| PanDiff | TGRS’23 | 0.934±0.095 | 4.575±0.255 | 3.742±0.353 | 0.980±0.007 | 0.058±0.015 | 0.064±0.020 | 0.881±0.075 |
| SSDiff | NeurIPS’24 | 0.934±0.094 | 4.464±0.747 | 3.632±0.275 | 0.982±0.008 | 0.031±0.011 | 0.036±0.013 | 0.934±0.021 |
| SGDiff | CVPR’25 | 0.938±0.087 | 4.353±0.741 | 3.578±0.290 | 0.983±0.007 | 0.023±0.013 | 0.043±0.012 | 0.934±0.011 |
| SALAD-Pan | Ours | **0.939±0.088** | **4.198±0.526** | **3.251±0.288** | 0.984±0.009 | **0.017±0.011** | **0.026±0.009** | **0.957±0.010** |

Table 3. Quantitative results on the GaoFen-2 (GF2) dataset. Best results are in bold.

| Models | Pub/Year | Q4 ↑ | SAM ↓ | ERGAS ↓ | SCC ↑ | Dλ ↓ | Ds ↓ | HQNR ↑ |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| PaNNet | ICCV’17 | 0.967±0.013 | 0.997±0.022 | 0.919±0.039 | 0.973±0.011 | 0.017±0.012 | 0.047±0.012 | 0.937±0.023 |
| FusionNet | TGRS’20 | 0.964±0.014 | 0.974±0.035 | 0.988±0.072 | 0.971±0.012 | 0.040±0.013 | 0.101±0.014 | 0.863±0.018 |
| LAGConv | AAAI’22 | 0.970±0.011 | 1.080±0.023 | 0.910±0.045 | 0.977±0.006 | 0.033±0.013 | 0.079±0.013 | 0.891±0.021 |
| BiMPAN | ACMM’23 | 0.965±0.020 | 0.902±0.066 | 0.881±0.058 | 0.972±0.018 | 0.032±0.015 | 0.051±0.014 | 0.918±0.019 |
| ARConv | CVPR’25 | 0.982±0.013 | 0.710±0.149 | 0.645±0.127 | **0.994±0.005** | 0.007±0.005 | 0.029±0.019 | 0.963±0.018 |
| WFANET | AAAI’25 | 0.981±0.007 | 0.751±0.082 | 0.657±0.074 | **0.994±0.002** | **0.003±0.003** | 0.032±0.021 | 0.964±0.020 |
| PanDiff | TGRS’23 | 0.979±0.011 | 0.888±0.037 | 0.746±0.031 | 0.988±0.003 | 0.027±0.011 | 0.073±0.013 | 0.903±0.025 |
| SSDiff | NeurIPS’24 | **0.983±0.007** | 0.670±0.124 | 0.604±0.108 | 0.991±0.006 | 0.016±0.009 | 0.027±0.027 | 0.957±0.010 |
| SGDiff | CVPR’25 | 0.980±0.011 | 0.708±0.119 | 0.668±0.094 | 0.989±0.005 | 0.020±0.013 | 0.024±0.022 | 0.959±0.011 |
| SALAD-Pan | Ours | 0.982±0.010 | **0.667±0.051** | **0.592±0.088** | 0.991±0.003 | 0.005±0.002 | **0.022±0.014** | **0.973±0.010** |

Qualitative Comparison

Reduced Resolution
Visual comparison on the WorldView-3 (WV3) and QuickBird (QB) datasets at reduced resolution.

Full Resolution
Visual comparison on the WorldView-3 (WV3) and QuickBird (QB) datasets at full resolution.

Efficiency Comparison (RR, QB)

| Diffusion-based Methods | SAM ↓ | ERGAS ↓ | NFE | Latency (s) ↓ |
| --- | --- | --- | --- | --- |
| PanDiff | 4.575±0.255 | 3.742±0.353 | 1000 | 356.63±1.98 |
| SSDiff | 4.464±0.747 | 3.632±0.275 | 10 | 10.10±0.21 |
| SGDiff | 4.353±0.741 | 3.578±0.290 | 50 | 6.64±0.09 |
| SALAD-Pan | 4.198±0.526 | 3.251±0.288 | 20 | 3.36±0.07 |

Latency is reported as mean ± std over 10 runs (warmup = 3), with batch size = 1, evaluated on the QB dataset under the reduced-resolution (RR) protocol on an RTX 4090 GPU.
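The latency protocol above (mean ± std over 10 runs, 3 warmup calls) can be reproduced with a small timing helper. This is a generic sketch of our own, not a script shipped with the repository; for GPU pipelines the timed callable should end with torch.cuda.synchronize() so queued kernels are included in the measurement:

```python
import time
import statistics

def benchmark(fn, warmup: int = 3, runs: int = 10):
    """Return (mean, std) wall-clock seconds of fn() over `runs` timed calls,
    preceded by `warmup` untimed calls. For CUDA work, fn must synchronize."""
    for _ in range(warmup):
        fn()
    samples = []
    for _ in range(runs):
        t0 = time.perf_counter()
        fn()
        samples.append(time.perf_counter() - t0)
    return statistics.mean(samples), statistics.stdev(samples)
```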

Citation

If you find our work useful, please cite:

@article{li2026saladpan,
  title={SALAD-Pan: Sensor-Agnostic Latent Adaptive Diffusion for Pan-Sharpening},
  author={Junjie Li and Congyang Ou and Haokui Zhang and Guoting Wei and Shengqin Jiang and Ying Li and Chunhua Shen},
  journal={arXiv preprint arXiv:2602.04473},
  year={2026}
}

Shoutouts

About

🤗 Official implementation for "SALAD-Pan: Sensor-Agnostic Latent Adaptive Diffusion for Pan-Sharpening" https://arxiv.org/abs/2602.04473
