diff --git a/docs/zh/examples/confild.md b/docs/zh/examples/confild.md
new file mode 100644
index 0000000000..96e167182c
--- /dev/null
+++ b/docs/zh/examples/confild.md
@@ -0,0 +1,369 @@
+# AI-Assisted Spatiotemporal Turbulence Generation: Conditional Neural Field Latent Diffusion Model (CoNFILD)
+
+Distributed under a Creative Commons Attribution 4.0 License (CC BY).
+
+## 1. Background
+### 1.1 Paper Information
+| Date | Journal | Authors | Citations | Paper & Supplementary Materials |
+|------|---------|---------|-----------|---------------------------------|
+| January 3, 2024 | Nature Communications | Pan Du, Meet Hemant Parikh, Xiantao Fan, Xin-Yang Liu, Jian-Xun Wang | 15 | [Paper](https://doi.org/10.1038/s41467-024-54712-1) <br> [Code repository](https://github.com/jx-wang-s-group/CoNFILD) |
+
+### 1.2 Authors
+- **Corresponding author**: Jian-Xun Wang (王建勋)
+  Affiliations: Department of Aerospace and Mechanical Engineering, University of Notre Dame; Department of Mechanical and Aerospace Engineering, Cornell University
+  Research interests: turbulence modeling, generative AI, physics-informed machine learning
+
+- **Other authors**:
+  Pan Du and Meet Hemant Parikh (co-first authors): Ph.D. students at the University of Notre Dame, working on generative models and computational fluid dynamics
+  Xiantao Fan and Xin-Yang Liu: research assistants at the University of Notre Dame, responsible for numerical simulation and data generation
+
+### 1.3 Model & Reproduction Code
+| Problem Type | Run Online | Network Architecture | Evaluation Metric |
+|--------------|------------|----------------------|-------------------|
+| Spatiotemporal turbulence generation | [AI Studio](https://aistudio.baidu.com/projectdetail/8933946) | Conditional neural field + latent diffusion model | MSE: 0.041 (velocity field) |
+
+=== "模型训练命令"
+```bash
+git clone https://github.com/PaddlePaddle/PaddleScience.git
+cd PaddleScience/examples/confild
+python confild.py mode=train
+```
+
+=== "预训练模型快速评估"
+
+``` sh
+python confild.py mode=eval
+```
+
+## 2. Problem Definition
+### 2.1 Background
+Turbulence simulation is essential in fields such as aerospace and ocean engineering, but traditional approaches like direct numerical simulation (DNS) and large-eddy simulation (LES) are computationally expensive and difficult to apply at high Reynolds numbers or in real-time settings. Most existing deep learning models are deterministic, struggle to capture the chaotic nature of turbulence, and perform poorly on complex geometric domains.
+
+### 2.2 Core Challenges
+1. **High-dimensional data**: 3D spatiotemporal turbulence data can reach \(O(10^9)\) degrees of freedom, giving traditional generative models prohibitive memory requirements.
+2. **Stochastic modeling**: both the multi-scale statistics and the instantaneous dynamics of turbulence must be captured.
+3. **Geometric adaptability**: irregular computational domains and adaptive meshes must be supported.
+
+### 2.3 Proposed Method
+The paper proposes the **Conditional Neural Field Latent Diffusion model (CoNFILD)**, a three-stage framework that addresses these challenges:
+1. **Neural field encoding**: compresses high-dimensional flow fields into low-dimensional latent representations, with compression ratios of 0.002%-0.017%.
+2. **Latent diffusion**: runs a probabilistic diffusion process in the latent space to learn the statistical distribution of turbulence.
+3. **Zero-shot conditional generation**: combined with Bayesian inference, enables tasks such as sensor-based reconstruction and super-resolution without retraining.
+
+
+*Framework overview: the CNF encoder maps flow fields to a latent space, the diffusion model generates new latent samples, and the decoder reconstructs the physical fields*
+
+## 3. Model Architecture
+### 3.1 Conditional Neural Field (CNF)
+- **Architecture**: built on the SIREN network, using sinusoidal activations to capture periodic features.
+- **Mathematical form** (see the sketch below):
+  $$
+  \mathscr{E}(\mathbf{x},\mathbf{L}) = \text{SIREN}(\mathbf{x}) + \text{FiLM}(\mathbf{L})
+  $$
+  where FiLM (Feature-wise Linear Modulation) modulates the bias of each layer via the latent vector \(\mathbf{L}\).
+
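+A minimal sketch of a FiLM-modulated SIREN layer in Paddle is shown below. This illustrates the mechanism only and is not the repository's exact implementation; names such as `FiLMSirenLayer` are hypothetical:
+
+```python
+import paddle
+
+
+class FiLMSirenLayer(paddle.nn.Layer):
+    # one SIREN layer whose pre-activation is shifted by a latent code (FiLM-style)
+    def __init__(self, in_dim, hidden_dim, latent_dim, omega0=30.0):
+        super().__init__()
+        self.coord_linear = paddle.nn.Linear(in_dim, hidden_dim)
+        # bias-only modulation from the latent vector, as described above
+        self.latent_linear = paddle.nn.Linear(latent_dim, hidden_dim, bias_attr=False)
+        self.omega0 = omega0
+
+    def forward(self, x, latent):
+        return paddle.sin(self.omega0 * (self.coord_linear(x) + self.latent_linear(latent)))
+
+
+layer = FiLMSirenLayer(in_dim=2, hidden_dim=128, latent_dim=128)
+y = layer(paddle.randn([918, 2]), paddle.randn([1, 128]))  # latent broadcasts over points
+print(y.shape)  # [918, 128]
+```
+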
+### 3.2 Latent Diffusion Model
+- **Forward process**: Gaussian noise is added step by step, taking the latent representation from \(\mathbf{z}_0\) to \(\mathbf{z}_T\).
+- **Reverse process**: a U-Net is trained to predict the noise, and new samples are generated by iterative denoising (see the sketch after this list):
+  $$
+  \mathbf{z}_{t-1} = \frac{1}{\sqrt{\alpha_t}} \left( \mathbf{z}_t - \frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}} \epsilon_\theta(\mathbf{z}_t, t) \right) + \sigma_t \epsilon
+  $$
+
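+The reverse update above corresponds to one ancestral denoising step. A simplified sketch follows, with a hypothetical noise predictor `eps_model`; the schedules `alphas`, `alphas_cumprod`, and `sigmas` are assumed to be 1-D Paddle tensors:
+
+```python
+import paddle
+
+
+def ddpm_reverse_step(z_t, t, eps_model, alphas, alphas_cumprod, sigmas):
+    # predict the noise epsilon_theta(z_t, t) and form the posterior mean
+    eps = eps_model(z_t, t)
+    mean = (
+        z_t - (1.0 - alphas[t]) / paddle.sqrt(1.0 - alphas_cumprod[t]) * eps
+    ) / paddle.sqrt(alphas[t])
+    # no noise is added at the final step t = 0
+    noise = paddle.randn(z_t.shape) if t > 0 else paddle.zeros_like(z_t)
+    return mean + sigmas[t] * noise
+```
+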
+### 3.3 Zero-Shot Conditional Generation
+- **Bayesian posterior sampling**: given sparse observations \(\Psi\), the latent-space samples are corrected via the gradient of the posterior (see the sketch below):
+  $$
+  \nabla_{\mathbf{z}_t} \log p(\mathbf{z}_t|\Psi) \approx \nabla_{\mathbf{z}_t} \log p(\Psi|\mathbf{z}_t) + \nabla_{\mathbf{z}_t} \log p(\mathbf{z}_t)
+  $$
+
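+A minimal sketch of this score decomposition, assuming a Gaussian observation model with noise level `obs_std`; all names here are illustrative, not the repository's API:
+
+```python
+import paddle
+
+
+def guided_score(z_t, obs, decode_fn, unconditional_score, obs_std=0.1):
+    # posterior score = likelihood gradient + unconditional (prior) score
+    z_t = z_t.detach()
+    z_t.stop_gradient = False
+    # log p(obs | z_t) under the assumed Gaussian observation model
+    log_lik = -0.5 * paddle.sum((decode_fn(z_t) - obs) ** 2) / obs_std**2
+    lik_grad = paddle.grad(log_lik, z_t)[0]
+    return lik_grad + unconditional_score(z_t)
+```
+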
+## 4. Problem Solving
+### 4.1 Dataset Preparation
+The data files are organized as follows:
+```
+data  # training dataset for the CNF
+|
+|-- data.npy  # field data to be fitted
+|
+|-- coords.npy  # query coordinates
+```
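+
+A minimal loading sketch (paths and array shapes are illustrative):
+
+```python
+import numpy as np
+import paddle
+
+fois = paddle.to_tensor(np.load("data/data.npy").astype("float32"))  # e.g. [T, N_points, C]
+coords = paddle.to_tensor(np.load("data/coords.npy").astype("float32"))  # e.g. [N_points, 2]
+```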
+
+After the data is loaded, it needs to be normalized for training. The code is as follows:
+```python
+class Normalizer_ts(object):
+ def __init__(self, params=[], method="-11", dim=None):
+ self.params = params
+ self.method = method
+ self.dim = dim
+
+    def fit_normalize(self, data):
+        assert isinstance(data, paddle.Tensor)
+        if len(self.params) == 0:
+            if self.method == "-11" or self.method == "01":
+                if self.dim is None:
+                    self.params = paddle.max(x=data), paddle.min(x=data)
+                else:
+                    self.params = (
+                        paddle.max(x=data, axis=self.dim, keepdim=True),
+                        paddle.min(x=data, axis=self.dim, keepdim=True),
+                    )
+ elif self.method == "ms":
+ if self.dim is None:
+ self.params = paddle.mean(x=data, axis=self.dim), paddle.std(
+ x=data, axis=self.dim
+ )
+ else:
+ self.params = paddle.mean(
+ x=data, axis=self.dim, keepdim=True
+ ), paddle.std(x=data, axis=self.dim, keepdim=True)
+ elif self.method == "none":
+ self.params = None
+ return self.fnormalize(data, self.params, self.method)
+
+ def normalize(self, new_data):
+ if not new_data.place == self.params[0].place:
+ self.params = self.params[0].to(new_data.place), self.params[1].to(
+ new_data.place
+ )
+ return self.fnormalize(new_data, self.params, self.method)
+
+ def denormalize(self, new_data_norm):
+ if not new_data_norm.place == self.params[0].place:
+ self.params = self.params[0].to(new_data_norm.place), self.params[1].to(
+ new_data_norm.place
+ )
+ return self.fdenormalize(new_data_norm, self.params, self.method)
+
+ def get_params(self):
+ if self.method == "ms":
+ print("returning mean and std")
+ elif self.method == "01":
+ print("returning max and min")
+ elif self.method == "-11":
+ print("returning max and min")
+ elif self.method == "none":
+ print("do nothing")
+ return self.params
+
+ @staticmethod
+ def fnormalize(data, params, method):
+ if method == "-11":
+ return (data - params[1].to(data.place)) / (
+ params[0].to(data.place) - params[1].to(data.place)
+ ) * 2 - 1
+ elif method == "01":
+ return (data - params[1].to(data.place)) / (
+ params[0].to(data.place) - params[1].to(data.place)
+ )
+ elif method == "ms":
+ return (data - params[0].to(data.place)) / params[1].to(data.place)
+ elif method == "none":
+ return data
+
+ @staticmethod
+ def fdenormalize(data_norm, params, method):
+ if method == "-11":
+ return (data_norm + 1) / 2 * (
+ params[0].to(data_norm.place) - params[1].to(data_norm.place)
+ ) + params[1].to(data_norm.place)
+ elif method == "01":
+ return data_norm * (
+ params[0].to(data_norm.place) - params[1].to(data_norm.place)
+ ) + params[1].to(data_norm.place)
+ elif method == "ms":
+ return data_norm * params[1].to(data_norm.place) + params[0].to(
+ data_norm.place
+ )
+ elif method == "none":
+ return data_norm
+```
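+
+A short usage example, assuming `fois` is the tensor loaded above:
+
+```python
+normalizer = Normalizer_ts(method="-11", dim=0)
+normed = normalizer.fit_normalize(fois)    # scale to [-1, 1] along dim 0
+restored = normalizer.denormalize(normed)  # recover the original range
+assert paddle.allclose(restored, fois, atol=1e-5).item()
+```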
+
+### 4.2 CoNFiLD Model
+CoNFiLD performs Bayesian posterior sampling with sparse sensor measurements as the conditioning input. The trained unconditional diffusion model serves as the prior, and the diffusion posterior sampling step accounts for the uncertainty introduced by measurement noise. Using the state-to-observation map, the unconditional score function is adjusted according to the relation between the condition vector and the flow field, steering generation toward full spatiotemporal flow fields that are consistent with the sensor data, while also providing uncertainty estimates for the reconstruction. The conditional neural field decoder used in this pipeline is implemented as follows:
+
+```python
+class SIRENAutodecoder_film(paddle.nn.Layer):
+ """
+    SIREN network with autodecoding
+
+ Args:
+ input_keys (Tuple[str,...], optional): Key to get the input tensor from the dict.
+ output_keys (Tuple[str,...], optional): Key to save the output tensor into the dict.
+ in_coord_features (int, optional): Number of input coordinates features
+ in_latent_features (int, optional): Number of input latent features
+ out_features (int, optional): Number of output features
+ num_hidden_layers (int, optional): Number of hidden layers
+ hidden_features (int, optional): Number of hidden features
+ outermost_linear (bool, optional): Whether to use linear layer at the end. Defaults to False.
+ nonlinearity (str, optional): Nonlinearity to use. Defaults to "sine".
+ weight_init (Callable, optional): Weight initialization function. Defaults to None.
+ bias_init (Callable, optional): Bias initialization function. Defaults to None.
+ premap_mode (str, optional): Feature mapping mode. Defaults to None.
+
+ Examples:
+ >>> model = ppsci.arch.SIRENAutodecoder_film(
+ input_keys=["input1", "input2"],
+ output_keys=("output",),
+ in_coord_features=2,
+ in_latent_features=128,
+ out_features=3,
+ num_hidden_layers=10,
+ hidden_features=128,
+ )
+ >>> input_data = {"input1": paddle.randn([10, 2]), "input2": paddle.randn([10, 128])}
+ >>> out_dict = model(input_data)
+ >>> for k, v in out_dict.items():
+ ... print(k, v.shape)
+        output [10, 3]
+ """
+
+ def __init__(
+ self,
+ input_keys,
+ output_keys,
+ in_coord_features,
+ in_latent_features,
+ out_features,
+ num_hidden_layers,
+ hidden_features,
+ outermost_linear=False,
+ nonlinearity="sine",
+ weight_init=None,
+ bias_init=None,
+ premap_mode=None,
+ **kwargs,
+ ):
+ super().__init__()
+ self.input_keys = input_keys
+ self.output_keys = output_keys
+
+ self.premap_mode = premap_mode
+ if self.premap_mode is not None:
+ self.premap_layer = FeatureMapping(
+ in_coord_features, mode=premap_mode, **kwargs
+ )
+ in_coord_features = self.premap_layer.dim
+ self.first_layer_init = None
+ self.nl, nl_weight_init, first_layer_init = NLS_AND_INITS[nonlinearity]
+ if weight_init is not None:
+ self.weight_init = weight_init
+ else:
+ self.weight_init = nl_weight_init
+ self.net1 = paddle.nn.LayerList(
+ sublayers=[BatchLinear(in_coord_features, hidden_features)]
+ + [
+ BatchLinear(hidden_features, hidden_features)
+ for i in range(num_hidden_layers)
+ ]
+ + [BatchLinear(hidden_features, out_features)]
+ )
+ self.net2 = paddle.nn.LayerList(
+ sublayers=[
+ BatchLinear(in_latent_features, hidden_features, bias_attr=False)
+ for i in range(num_hidden_layers + 1)
+ ]
+ )
+ if self.weight_init is not None:
+ self.net1.apply(self.weight_init)
+ self.net2.apply(self.weight_init)
+ if first_layer_init is not None:
+ self.net1[0].apply(first_layer_init)
+ self.net2[0].apply(first_layer_init)
+ if bias_init is not None:
+ self.net2.apply(bias_init)
+
+ def forward(self, input_data):
+ coords = input_data[self.input_keys[0]]
+ latents = input_data[self.input_keys[1]]
+ if self.premap_mode is not None:
+ x = self.premap_layer(coords)
+ else:
+ x = coords
+
+ for i in range(len(self.net1) - 1):
+ x = self.net1[i](x) + self.net2[i](latents)
+ x = self.nl(x)
+ x = self.net1[-1](x)
+ return {self.output_keys[0]: x}
+
+ def disable_gradient(self):
+ for param in self.parameters():
+            param.stop_gradient = True
+```
+To access specific variables quickly and reliably during computation, the network's input variables are named ["confild_x", "latent_z"] and its output variable is named ["confild_output"]; these names are kept consistent throughout the rest of the code, as the hypothetical call below illustrates.
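+
+For example (shapes are illustrative):
+
+```python
+model = SIRENAutodecoder_film(
+    input_keys=["confild_x", "latent_z"],
+    output_keys=["confild_output"],
+    in_coord_features=2,
+    in_latent_features=128,
+    out_features=3,
+    num_hidden_layers=10,
+    hidden_features=128,
+)
+out = model({"confild_x": paddle.randn([918, 2]), "latent_z": paddle.randn([1, 128])})
+print(out["confild_output"].shape)  # [918, 3]
+```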
+
+### 4.3 Model Training and Evaluation
+With the above components in place, the instantiated objects are assembled as documented, and training and evaluation can be launched.
+```python
+def signal_train(cfg, normed_coords, normed_fois, spatio_axis, out_normalizer):
+ cnf_model = SIRENAutodecoder_film(**cfg.CONFILD)
+ latents_model = LatentContainer(**cfg.Latent)
+
+ dataset = basic_set(normed_fois, normed_coords)
+ criterion = paddle.nn.MSELoss()
+
+ # set loader
+ train_loader = DataLoader(
+ dataset=dataset, batch_size=cfg.TRAIN.batch_size, shuffle=True
+ )
+ test_loader = DataLoader(
+ dataset=dataset, batch_size=cfg.TRAIN.test_batch_size, shuffle=False
+ )
+ # set optimizer
+ cnf_optimizer = ppsci.optimizer.Adam(cfg.TRAIN.lr.cnf, weight_decay=0.0)(cnf_model)
+ latents_optimizer = ppsci.optimizer.Adam(cfg.TRAIN.lr.latents, weight_decay=0.0)(
+ latents_model
+ )
+
+ for i in range(cfg.TRAIN.epochs):
+ cnf_model.train()
+ latents_model.train()
+ if i != 0:
+ cnf_optimizer.step()
+ cnf_optimizer.clear_grad(set_to_zero=False)
+ train_loss = []
+ for batch_coords, batch_fois, idx in train_loader:
+ idx = {"latent_x": idx}
+ batch_latent = latents_model(idx)
+ if isinstance(batch_coords, list):
+ batch_coords = [i for i in batch_coords]
+ data = {
+ "confild_x": batch_coords,
+ "latent_z": batch_latent["latent_z"],
+ }
+ batch_output = cnf_model(data)
+ loss = criterion(batch_output["confild_output"], batch_fois)
+ latents_optimizer.clear_grad(set_to_zero=False)
+ loss.backward()
+ latents_optimizer.step()
+            train_loss.append(loss)
+        epoch_loss = paddle.stack(x=train_loss).mean().item()
+        print("epoch {}, train loss {}".format(i + 1, epoch_loss))
+ if i % 100 == 0:
+ test_error = []
+ cnf_model.eval()
+ latents_model.eval()
+ with paddle.no_grad():
+ for test_coords, test_fois, idx in test_loader:
+ if isinstance(test_coords, list):
+ test_coords = [i for i in test_coords]
+                    prediction = out_normalizer.denormalize(
+                        cnf_model(
+                            {
+                                "confild_x": test_coords,
+                                "latent_z": latents_model({"latent_x": idx})[
+                                    "latent_z"
+                                ],
+                            }
+                        )["confild_output"]
+                    )
+ target = out_normalizer.denormalize(test_fois)
+ error = rMAE(prediction=prediction, target=target, dims=spatio_axis)
+ test_error.append(error)
+ test_error = paddle.concat(x=test_error).mean(axis=0)
+ print("test MAE: ", test_error)
+ if i % 1000 == 0:
+ paddle.save(cnf_model.state_dict(), f"cnf_model_{i}.pdparams")
+ paddle.save(latents_model.state_dict(), f"latents_model_{i}.pdparams")
+```
+
+## 5. Experimental Results
diff --git a/examples/confild/conf/confild_case1.yaml b/examples/confild/conf/confild_case1.yaml
new file mode 100644
index 0000000000..4bcec2db32
--- /dev/null
+++ b/examples/confild/conf/confild_case1.yaml
@@ -0,0 +1,123 @@
+defaults:
+ - ppsci_default
+ - TRAIN: train_default
+ - TRAIN/ema: ema_default
+ - TRAIN/swa: swa_default
+ - EVAL: eval_default
+ - INFER: infer_default
+ - hydra/job/config/override_dirname/exclude_keys: exclude_keys_default
+ - _self_
+
+hydra:
+ run:
+ # dynamic output directory according to running time and override name
+ # dir: outputs_confild_case1/${now:%Y-%m-%d}/${now:%H-%M-%S}/${hydra.job.override_dirname}
+ dir: ./outputs_confild_case1
+ job:
+ name: ${mode} # name of logfile
+ chdir: false # keep current working directory unchanged
+ callbacks:
+ init_callback:
+ _target_: ppsci.utils.callbacks.InitCallback
+ sweep:
+ # output directory for multirun
+ dir: ${hydra.run.dir}
+ subdir: ./
+
+# general settings
+mode: infer # running mode: infer
+seed: 2025
+output_dir: ${hydra:run.dir}
+log_freq: 20
+
+TRAIN:
+ batch_size: 64
+ test_batch_size: 256
+ epochs: 9800
+ mutil_GPU: 1
+ lr:
+ cnf: 1.e-4
+ latents: 1.e-5
+
+EVAL:
+ confild_pretrained_model_path: ./outputs_confild_case1/confild_case1/epoch_99999
+ latent_pretrained_model_path: ./outputs_confild_case1/latent_case1/epoch_99999
+
+CONFILD:
+ input_keys: ["confild_x", "latent_z"]
+ output_keys: ["confild_output"]
+ num_hidden_layers: 10
+ out_features: 3
+ hidden_features: 128
+ in_coord_features: 2
+ in_latent_features: 128
+
+Latent:
+ input_keys: ["latent_x"]
+ output_keys: ["latent_z"]
+ N_samples: 16000
+ lumped: True
+ N_features: 128
+ dims: 2
+
+INFER:
+ Latent:
+ INFER:
+ pretrained_model_path: null
+ export_path: ./inference/latent_case1
+ pdmodel_path: ${INFER.Latent.INFER.export_path}.pdmodel
+ pdiparams_path: ${INFER.Latent.INFER.export_path}.pdiparams
+ onnx_path: ${INFER.Latent.INFER.export_path}.onnx
+ device: gpu
+ engine: native
+ precision: fp32
+ ir_optim: true
+ min_subgraph_size: 5
+ gpu_mem: 2000
+ gpu_id: 0
+ max_batch_size: 1024
+ num_cpu_threads: 10
+ log_freq: 20
+ Confild:
+ INFER:
+ pretrained_model_path: null
+ export_path: ./inference/confild_case1
+ pdmodel_path: ${INFER.Confild.INFER.export_path}.pdmodel
+ pdiparams_path: ${INFER.Confild.INFER.export_path}.pdiparams
+ onnx_path: ${INFER.Confild.INFER.export_path}.onnx
+ device: gpu
+ engine: native
+ precision: fp32
+ ir_optim: true
+ min_subgraph_size: 5
+ gpu_mem: 2000
+ gpu_id: 0
+ max_batch_size: 1024
+ num_cpu_threads: 10
+ coord_shape: [918, 2]
+ latents_shape: [1, 128]
+ log_freq: 20
+ batch_size: 64
+
+Uncondiction_INFER:
+ batch_size : 16
+ test_batch_size : 16
+ time_length : 128
+ latent_length : 128
+ image_size : 128
+ num_channels: 128
+ num_res_blocks: 2
+ num_heads: 4
+ num_head_channels: 64
+ attention_resolutions: "32,16,8"
+ channel_mult: null
+ steps: 1000
+ noise_schedule: "cosine"
+
+Data:
+ data_path: /home/aistudio/work/extracted/data/Case1/data.npy
+ coor_path: /home/aistudio/work/extracted/data/Case1/coords.npy
+ normalizer:
+ method: "-11"
+ dim: 0
+ load_data_fn: load_elbow_flow
diff --git a/examples/confild/conf/confild_case2.yaml b/examples/confild/conf/confild_case2.yaml
new file mode 100644
index 0000000000..0ab1da60a4
--- /dev/null
+++ b/examples/confild/conf/confild_case2.yaml
@@ -0,0 +1,108 @@
+defaults:
+ - ppsci_default
+ - TRAIN: train_default
+ - TRAIN/ema: ema_default
+ - TRAIN/swa: swa_default
+ - EVAL: eval_default
+ - INFER: infer_default
+ - hydra/job/config/override_dirname/exclude_keys: exclude_keys_default
+ - _self_
+
+hydra:
+ run:
+ # dynamic output directory according to running time and override name
+ # dir: outputs_confild_case2/${now:%Y-%m-%d}/${now:%H-%M-%S}/${hydra.job.override_dirname}
+ dir: ./outputs_confild_case2
+ job:
+ name: ${mode} # name of logfile
+ chdir: false # keep current working directory unchanged
+ callbacks:
+ init_callback:
+ _target_: ppsci.utils.callbacks.InitCallback
+ sweep:
+ # output directory for multirun
+ dir: ${hydra.run.dir}
+ subdir: ./
+
+# general settings
+mode: infer # running mode: infer
+seed: 2025
+output_dir: ${hydra:run.dir}
+log_freq: 20
+
+TRAIN:
+ batch_size: 40
+ test_batch_size: 40
+ epochs: 44500
+ mutil_GPU: 1
+ lr:
+ cnf: 1.e-4
+ latents: 1.e-5
+
+EVAL:
+ confild_pretrained_model_path: ./outputs_confild_case2/confild_case2/epoch_99999
+ latent_pretrained_model_path: ./outputs_confild_case2/latent_case2/epoch_99999
+
+CONFILD:
+ input_keys: ["confild_x", "latent_z"]
+ output_keys: ["confild_output"]
+ num_hidden_layers: 10
+ out_features: 4
+ hidden_features: 256
+ in_coord_features: 2
+ in_latent_features: 256
+
+Latent:
+ input_keys: ["latent_x"]
+ output_keys: ["latent_z"]
+ N_samples: 1200
+ lumped: False
+ N_features: 256
+ dims: 2
+
+INFER:
+ Latent:
+ INFER:
+ pretrained_model_path: null
+ export_path: ./inference/latent_case2
+ pdmodel_path: ${INFER.Latent.INFER.export_path}.pdmodel
+ pdiparams_path: ${INFER.Latent.INFER.export_path}.pdiparams
+ onnx_path: ${INFER.Latent.INFER.export_path}.onnx
+ device: gpu
+ engine: native
+ precision: fp32
+ ir_optim: true
+ min_subgraph_size: 5
+ gpu_mem: 2000
+ gpu_id: 0
+ max_batch_size: 1024
+ num_cpu_threads: 10
+ log_freq: 20
+ Confild:
+ INFER:
+ pretrained_model_path: null
+ export_path: ./inference/confild_case2
+ pdmodel_path: ${INFER.Confild.INFER.export_path}.pdmodel
+ pdiparams_path: ${INFER.Confild.INFER.export_path}.pdiparams
+ onnx_path: ${INFER.Confild.INFER.export_path}.onnx
+ device: gpu
+ engine: native
+ precision: fp32
+ ir_optim: true
+ min_subgraph_size: 5
+ gpu_mem: 2000
+ gpu_id: 0
+ max_batch_size: 1024
+ num_cpu_threads: 10
+ coord_shape: [400, 100, 2]
+ latents_shape: [1, 1, 256]
+ log_freq: 20
+ batch_size: 40
+
+Data:
+ data_path: /home/aistudio/work/extracted/data/Case2/data.npy
+ # coor_path: ../case2/coords.npy
+ normalizer:
+ method: "-11"
+ dim: 0
+ load_data_fn: load_channel_flow
diff --git a/examples/confild/conf/confild_case3.yaml b/examples/confild/conf/confild_case3.yaml
new file mode 100644
index 0000000000..9369a7e944
--- /dev/null
+++ b/examples/confild/conf/confild_case3.yaml
@@ -0,0 +1,108 @@
+defaults:
+ - ppsci_default
+ - TRAIN: train_default
+ - TRAIN/ema: ema_default
+ - TRAIN/swa: swa_default
+ - EVAL: eval_default
+ - INFER: infer_default
+ - hydra/job/config/override_dirname/exclude_keys: exclude_keys_default
+ - _self_
+
+hydra:
+ run:
+ # dynamic output directory according to running time and override name
+ # dir: outputs_confild_case3/${now:%Y-%m-%d}/${now:%H-%M-%S}/${hydra.job.override_dirname}
+ dir: ./outputs_confild_case3
+ job:
+ name: ${mode} # name of logfile
+ chdir: false # keep current working directory unchanged
+ callbacks:
+ init_callback:
+ _target_: ppsci.utils.callbacks.InitCallback
+ sweep:
+ # output directory for multirun
+ dir: ${hydra.run.dir}
+ subdir: ./
+
+# general settings
+mode: infer # running mode: infer
+seed: 2025
+output_dir: ${hydra:run.dir}
+log_freq: 20
+
+TRAIN:
+ batch_size: 100
+ test_batch_size: 100
+ epochs: 4800
+ mutil_GPU: 2
+ lr:
+ cnf: 1.e-4
+ latents: 1.e-5
+
+EVAL:
+ confild_pretrained_model_path: ./outputs_confild_case3/confild_case3/epoch_99999
+ latent_pretrained_model_path: ./outputs_confild_case3/latent_case3/epoch_99999
+
+CONFILD:
+ input_keys: ["confild_x", "latent_z"]
+ output_keys: ["confild_output"]
+ num_hidden_layers: 117
+ out_features: 2
+ hidden_features: 256
+ in_coord_features: 2
+ in_latent_features: 256
+
+Latent:
+ input_keys: ["latent_x"]
+ output_keys: ["latent_z"]
+ N_samples: 2880
+ lumped: True
+ N_features: 256
+ dims: 2
+
+INFER:
+ Latent:
+ INFER:
+ pretrained_model_path: null
+ export_path: ./inference/latent_case3
+ pdmodel_path: ${INFER.Latent.INFER.export_path}.pdmodel
+ pdiparams_path: ${INFER.Latent.INFER.export_path}.pdiparams
+ onnx_path: ${INFER.Latent.INFER.export_path}.onnx
+ device: gpu
+ engine: native
+ precision: fp32
+ ir_optim: true
+ min_subgraph_size: 5
+ gpu_mem: 2000
+ gpu_id: 0
+ max_batch_size: 1024
+ num_cpu_threads: 10
+ log_freq: 20
+ Confild:
+ INFER:
+ pretrained_model_path: null
+ export_path: ./inference/confild_case3
+ pdmodel_path: ${INFER.Confild.INFER.export_path}.pdmodel
+ pdiparams_path: ${INFER.Confild.INFER.export_path}.pdiparams
+ onnx_path: ${INFER.Confild.INFER.export_path}.onnx
+ device: gpu
+ engine: native
+ precision: fp32
+ ir_optim: true
+ min_subgraph_size: 5
+ gpu_mem: 2000
+ gpu_id: 0
+ max_batch_size: 1024
+ num_cpu_threads: 10
+ coord_shape: [10884, 2]
+ latents_shape: [1, 256]
+ log_freq: 20
+ batch_size: 100
+
+Data:
+ data_path: /home/aistudio/work/extracted/data/Case3/data.npy
+ coor_path: /home/aistudio/work/extracted/data/Case3/coords.npy
+ normalizer:
+ method: "-11"
+ dim: 0
+ load_data_fn: load_periodic_hill_flow
diff --git a/examples/confild/conf/confild_case4.yaml b/examples/confild/conf/confild_case4.yaml
new file mode 100644
index 0000000000..820c5a92dc
--- /dev/null
+++ b/examples/confild/conf/confild_case4.yaml
@@ -0,0 +1,108 @@
+defaults:
+ - ppsci_default
+ - TRAIN: train_default
+ - TRAIN/ema: ema_default
+ - TRAIN/swa: swa_default
+ - EVAL: eval_default
+ - INFER: infer_default
+ - hydra/job/config/override_dirname/exclude_keys: exclude_keys_default
+ - _self_
+
+hydra:
+ run:
+ # dynamic output directory according to running time and override name
+ # dir: outputs_confild_case4/${now:%Y-%m-%d}/${now:%H-%M-%S}/${hydra.job.override_dirname}
+ dir: ./outputs_confild_case4
+ job:
+ name: ${mode} # name of logfile
+ chdir: false # keep current working directory unchanged
+ callbacks:
+ init_callback:
+ _target_: ppsci.utils.callbacks.InitCallback
+ sweep:
+ # output directory for multirun
+ dir: ${hydra.run.dir}
+ subdir: ./
+
+# general settings
+mode: infer # running mode: infer
+seed: 2025
+output_dir: ${hydra:run.dir}
+log_freq: 20
+
+TRAIN:
+ batch_size: 4
+ test_batch_size: 4
+ epochs: 20000
+ mutil_GPU: 2
+ lr:
+ cnf: 1.e-4
+ latents: 1.e-5
+
+EVAL:
+ confild_pretrained_model_path: ./outputs_confild_case4/confild_case4/epoch_99999
+ latent_pretrained_model_path: ./outputs_confild_case4/latent_case4/epoch_99999
+
+CONFILD:
+ input_keys: ["confild_x", "latent_z"]
+ output_keys: ["confild_output"]
+ num_hidden_layers: 15
+ out_features: 3
+ hidden_features: 384
+ in_coord_features: 3
+ in_latent_features: 384
+
+Latent:
+ input_keys: ["latent_x"]
+ output_keys: ["latent_z"]
+ N_samples: 1200
+ lumped: True
+ N_features: 384
+ dims: 3
+
+INFER:
+ Latent:
+ INFER:
+ pretrained_model_path: null
+ export_path: ./inference/latent_case4
+ pdmodel_path: ${INFER.Latent.INFER.export_path}.pdmodel
+ pdiparams_path: ${INFER.Latent.INFER.export_path}.pdiparams
+ onnx_path: ${INFER.Latent.INFER.export_path}.onnx
+ device: gpu
+ engine: native
+ precision: fp32
+ ir_optim: true
+ min_subgraph_size: 5
+ gpu_mem: 2000
+ gpu_id: 0
+ max_batch_size: 1024
+ num_cpu_threads: 10
+ log_freq: 20
+ Confild:
+ INFER:
+ pretrained_model_path: null
+ export_path: ./inference/confild_case4
+ pdmodel_path: ${INFER.Confild.INFER.export_path}.pdmodel
+ pdiparams_path: ${INFER.Confild.INFER.export_path}.pdiparams
+ onnx_path: ${INFER.Confild.INFER.export_path}.onnx
+ device: gpu
+ engine: native
+ precision: fp32
+ ir_optim: true
+ min_subgraph_size: 5
+ gpu_mem: 2000
+ gpu_id: 0
+ max_batch_size: 1024
+ num_cpu_threads: 10
+ coord_shape: [58483, 3]
+ latents_shape: [1, 384]
+ log_freq: 20
+ batch_size: 4
+
+Data:
+ data_path: /home/aistudio/work/extracted/data/Case4/data.npy
+ coor_path: /home/aistudio/work/extracted/data/Case4/coords.npy
+ normalizer:
+ method: "-11"
+ dim: 0
+ load_data_fn: load_3d_flow
diff --git a/examples/confild/conf/un_confild_case1.yaml b/examples/confild/conf/un_confild_case1.yaml
new file mode 100644
index 0000000000..76b0c8f27a
--- /dev/null
+++ b/examples/confild/conf/un_confild_case1.yaml
@@ -0,0 +1,92 @@
+defaults:
+ - ppsci_default
+ - TRAIN: train_default
+ - TRAIN/ema: ema_default
+ - TRAIN/swa: swa_default
+ - EVAL: eval_default
+ - INFER: infer_default
+ - hydra/job/config/override_dirname/exclude_keys: exclude_keys_default
+ - _self_
+
+hydra:
+ run:
+ # dynamic output directory according to running time and override name
+ # dir: outputs_confild_case1/${now:%Y-%m-%d}/${now:%H-%M-%S}/${hydra.job.override_dirname}
+ dir: ./outputs_un_confild_case1
+ job:
+ name: ${mode} # name of logfile
+ chdir: false # keep current working directory unchanged
+ callbacks:
+ init_callback:
+ _target_: ppsci.utils.callbacks.InitCallback
+ sweep:
+ # output directory for multirun
+ dir: ${hydra.run.dir}
+ subdir: ./
+
+# general settings
+mode: eval # running mode: eval
+seed: 2025
+output_dir: ${hydra:run.dir}
+log_freq: 20
+save_path: ${output_dir}/result.npy
+
+TRAIN:
+ batch_size : 16
+ test_batch_size : 16
+ ema_rate: "0.9999"
+ lr_anneal_steps: 0
+ lr : 5.e-5
+ weight_decay: 0.0
+ final_lr: 0.
+ microbatch: -1
+
+EVAL:
+ mutil_GPU: 1
+ lr : 5.e-5
+ ema_rate: "0.9999"
+ log_interval: 1000
+ save_interval: 10000
+ lr_anneal_steps: 0
+ time_length : 128
+ latent_length : 128
+ test_batch_size: 16
+
+UNET:
+ image_size : 128
+ num_channels: 128
+ num_res_blocks: 2
+ num_heads: 4
+ num_head_channels: 64
+ attention_resolutions: "32,16,8"
+ channel_mult: null
+ ema_path: /home/aistudio/work/extracted/data/Case1/diffusion/ema.pdparams
+
+Diff:
+ steps: 1000
+ noise_schedule: "cosine"
+
+CNF:
+ mutil_GPU: 1
+ data_path: /home/aistudio/work/extracted/data/Case1/data.npy
+ coor_path: /home/aistudio/work/extracted/data/Case1/coords.npy
+ load_data_fn: load_elbow_flow
+ normalizer:
+ method: "-11"
+ dim: 0
+ CONFILD:
+ input_keys: ["confild_x", "latent_z"]
+ output_keys: ["confild_output"]
+ num_hidden_layers: 10
+ out_features: 3
+ hidden_features: 128
+ in_coord_features: 2
+ in_latent_features: 128
+ normalizer_params_path: /home/aistudio/work/extracted/data/Case1/cnf/normalizer_params.pdparams
+ model_path: ./outputs_confild_case1/confild_case1/epoch_99999
+
+DATA:
+ max_val: 1.0
+ min_val: -1.0
+ train_data: "/home/aistudio/work/extracted/data/Case1/train_data.npy"
+ valid_data: "/home/aistudio/work/extracted/data/Case1/valid_data.npy"
\ No newline at end of file
diff --git a/examples/confild/conf/un_confild_case2.yaml b/examples/confild/conf/un_confild_case2.yaml
new file mode 100644
index 0000000000..641a10b9d7
--- /dev/null
+++ b/examples/confild/conf/un_confild_case2.yaml
@@ -0,0 +1,92 @@
+defaults:
+ - ppsci_default
+ - TRAIN: train_default
+ - TRAIN/ema: ema_default
+ - TRAIN/swa: swa_default
+ - EVAL: eval_default
+ - INFER: infer_default
+ - hydra/job/config/override_dirname/exclude_keys: exclude_keys_default
+ - _self_
+
+hydra:
+ run:
+ # dynamic output directory according to running time and override name
+ # dir: outputs_confild_case1/${now:%Y-%m-%d}/${now:%H-%M-%S}/${hydra.job.override_dirname}
+ dir: ./outputs_un_confild_case2
+ job:
+ name: ${mode} # name of logfile
+ chdir: false # keep current working directory unchanged
+ callbacks:
+ init_callback:
+ _target_: ppsci.utils.callbacks.InitCallback
+ sweep:
+ # output directory for multirun
+ dir: ${hydra.run.dir}
+ subdir: ./
+
+# general settings
+mode: eval # running mode: eval
+seed: 2025
+output_dir: ${hydra:run.dir}
+log_freq: 20
+save_path: ${output_dir}/result.npy
+
+TRAIN:
+ batch_size : 16
+ test_batch_size : 16
+ ema_rate: "0.9999"
+ lr_anneal_steps: 0
+ lr : 5.e-5
+ weight_decay: 0.0
+ final_lr: 0.
+ microbatch: -1
+
+EVAL:
+ mutil_GPU: 1
+ lr : 5.e-5
+ ema_rate: "0.9999"
+ log_interval: 1000
+ save_interval: 10000
+ lr_anneal_steps: 0
+ time_length : 256
+ latent_length : 256
+ test_batch_size: 16
+
+UNET:
+ image_size : 256
+ num_channels: 128
+ num_res_blocks: 2
+ num_heads: 4
+ num_head_channels: 64
+ attention_resolutions: "32,16,8"
+ channel_mult: null
+ ema_path: /home/aistudio/work/extracted/data/Case2/diffusion/ema.pdparams
+
+Diff:
+ steps: 1000
+ noise_schedule: "cosine"
+
+CNF:
+ mutil_GPU: 1
+ data_path: /home/aistudio/work/extracted/data/Case2/data.npy
+ # coor_path: ../Case2/coords.npy
+ load_data_fn: load_channel_flow
+ normalizer:
+ method: "-11"
+ dim: 0
+ CONFILD:
+ input_keys: ["confild_x", "latent_z"]
+ output_keys: ["confild_output"]
+ num_hidden_layers: 10
+ out_features: 4
+ hidden_features: 256
+ in_coord_features: 2
+ in_latent_features: 256
+ normalizer_params_path: /home/aistudio/work/extracted/data/Case2/cnf/normalizer_params.pdparams
+ model_path: ./outputs_confild_case2/confild_case2/epoch_99999
+
+DATA:
+ max_val: 1.0
+ min_val: -1.0
+ train_data: "/home/aistudio/work/extracted/data/Case2/train_data.npy"
+ valid_data: "/home/aistudio/work/extracted/data/Case2/valid_data.npy"
\ No newline at end of file
diff --git a/examples/confild/conf/un_confild_case3.yaml b/examples/confild/conf/un_confild_case3.yaml
new file mode 100644
index 0000000000..65be7105db
--- /dev/null
+++ b/examples/confild/conf/un_confild_case3.yaml
@@ -0,0 +1,92 @@
+defaults:
+ - ppsci_default
+ - TRAIN: train_default
+ - TRAIN/ema: ema_default
+ - TRAIN/swa: swa_default
+ - EVAL: eval_default
+ - INFER: infer_default
+ - hydra/job/config/override_dirname/exclude_keys: exclude_keys_default
+ - _self_
+
+hydra:
+ run:
+ # dynamic output directory according to running time and override name
+ # dir: outputs_confild_case1/${now:%Y-%m-%d}/${now:%H-%M-%S}/${hydra.job.override_dirname}
+ dir: ./outputs_un_confild_case3
+ job:
+ name: ${mode} # name of logfile
+ chdir: false # keep current working directory unchanged
+ callbacks:
+ init_callback:
+ _target_: ppsci.utils.callbacks.InitCallback
+ sweep:
+ # output directory for multirun
+ dir: ${hydra.run.dir}
+ subdir: ./
+
+# general settings
+mode: eval # running mode: eval
+seed: 2025
+output_dir: ${hydra:run.dir}
+log_freq: 20
+save_path: ${output_dir}/result.npy
+
+TRAIN:
+ batch_size : 16
+ test_batch_size : 16
+ ema_rate: "0.9999"
+ lr_anneal_steps: 0
+ lr : 5.e-5
+ weight_decay: 0.0
+ final_lr: 0.
+ microbatch: -1
+
+EVAL:
+ mutil_GPU: 2
+ lr : 5.e-5
+ ema_rate: "0.9999"
+ log_interval: 1000
+ save_interval: 10000
+ lr_anneal_steps: 0
+ time_length : 256
+ latent_length : 256
+ test_batch_size: 16
+
+UNET:
+ image_size : 256
+ num_channels: 128
+ num_res_blocks: 2
+ num_heads: 4
+ num_head_channels: 64
+ attention_resolutions: "32,16,8"
+ channel_mult: null
+ ema_path: /home/aistudio/work/extracted/data/Case3/diffusion/ema.pdparams
+
+Diff:
+ steps: 1000
+ noise_schedule: "cosine"
+
+CNF:
+ mutil_GPU: 1
+ data_path: /home/aistudio/work/extracted/data/Case3/data.npy
+ coor_path: /home/aistudio/work/extracted/data/Case3/coords.npy
+ normalizer:
+ method: "-11"
+ dim: 0
+ load_data_fn: load_periodic_hill_flow
+ CONFILD:
+ input_keys: ["confild_x", "latent_z"]
+ output_keys: ["confild_output"]
+ num_hidden_layers: 117
+ out_features: 2
+ hidden_features: 256
+ in_coord_features: 2
+ in_latent_features: 256
+ normalizer_params_path: /home/aistudio/work/extracted/data/Case3/cnf/normalizer_params.pdparams
+ model_path: ./outputs_confild_case3/confild_case3/epoch_99999
+
+DATA:
+ min_val: -1.0
+ max_val: 1.0
+ train_data: "/home/aistudio/work/extracted/data/Case3/train_data.npy"
+ valid_data: "/home/aistudio/work/extracted/data/Case3/valid_data.npy"
diff --git a/examples/confild/conf/un_confild_case4.yaml b/examples/confild/conf/un_confild_case4.yaml
new file mode 100644
index 0000000000..9d5c5dd3b8
--- /dev/null
+++ b/examples/confild/conf/un_confild_case4.yaml
@@ -0,0 +1,93 @@
+defaults:
+ - ppsci_default
+ - TRAIN: train_default
+ - TRAIN/ema: ema_default
+ - TRAIN/swa: swa_default
+ - EVAL: eval_default
+ - INFER: infer_default
+ - hydra/job/config/override_dirname/exclude_keys: exclude_keys_default
+ - _self_
+
+hydra:
+ run:
+ # dynamic output directory according to running time and override name
+ # dir: outputs_confild_case1/${now:%Y-%m-%d}/${now:%H-%M-%S}/${hydra.job.override_dirname}
+ dir: ./outputs_un_confild_case4
+ job:
+ name: ${mode} # name of logfile
+ chdir: false # keep current working directory unchanged
+ callbacks:
+ init_callback:
+ _target_: ppsci.utils.callbacks.InitCallback
+ sweep:
+ # output directory for multirun
+ dir: ${hydra.run.dir}
+ subdir: ./
+
+# general settings
+mode: eval # running mode: eval
+seed: 2025
+output_dir: ${hydra:run.dir}
+log_freq: 20
+save_path: ${output_dir}/result.npy
+
+TRAIN:
+ batch_size : 8
+ test_batch_size : 8
+ ema_rate: "0.9999"
+ lr_anneal_steps: 0
+ lr : 5.e-5
+ weight_decay: 0.0
+ final_lr: 0.
+ microbatch: -1
+
+EVAL:
+ mutil_GPU: 2
+ lr : 5.e-5
+ ema_rate: "0.9999"
+ log_interval: 1000
+ save_interval: 10000
+ lr_anneal_steps: 0
+ time_length : 256
+ latent_length : 256
+ test_batch_size: 8
+
+UNET:
+
+ image_size : 384
+ num_channels: 128
+ num_res_blocks: 2
+ num_heads: 4
+ num_head_channels: 64
+ attention_resolutions: "32,16,8"
+ channel_mult: "1, 1, 2, 2, 4, 4"
+ ema_path: /home/aistudio/work/extracted/data/Case4/diffusion/ema.pdparams
+
+Diff:
+ steps: 1000
+ noise_schedule: "cosine"
+
+CNF:
+ mutil_GPU: 1
+ data_path: /home/aistudio/work/extracted/data/Case4/data.npy
+ coor_path: /home/aistudio/work/extracted/data/Case4/coords.npy
+ normalizer:
+ method: "-11"
+ dim: 0
+ load_data_fn: load_3d_flow
+ CONFILD:
+ input_keys: ["confild_x", "latent_z"]
+ output_keys: ["confild_output"]
+ num_hidden_layers: 15
+ out_features: 3
+ hidden_features: 384
+ in_coord_features: 3
+ in_latent_features: 384
+ normalizer_params_path: /home/aistudio/work/extracted/data/Case4/cnf/normalizer_params.pdparams
+ model_path: ./outputs_confild_case4/confild_case4/epoch_99999
+
+DATA:
+ min_val: -1.0
+ max_val: 1.0
+ train_data: "/home/aistudio/work/extracted/data/Case4/train_data.npy"
+ valid_data: "/home/aistudio/work/extracted/data/Case4/valid_data.npy"
diff --git a/examples/confild/confild.py b/examples/confild/confild.py
new file mode 100644
index 0000000000..8003010630
--- /dev/null
+++ b/examples/confild/confild.py
@@ -0,0 +1,593 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import hydra
+import matplotlib.pyplot as plt
+import numpy as np
+import paddle
+from omegaconf import DictConfig
+from paddle.distributed import fleet
+from paddle.io import DataLoader
+from paddle.io import DistributedBatchSampler
+
+import ppsci
+from ppsci.arch import LatentContainer
+from ppsci.arch import SIRENAutodecoder_film
+from ppsci.utils import logger
+
+
+def load_elbow_flow(path):
+    return np.load(path)[1:]
+
+
+def load_channel_flow(
+    path,
+    t_start=0,
+    t_end=1200,
+    t_every=1,
+):
+    return np.load(path)[t_start:t_end:t_every]
+
+
+def load_periodic_hill_flow(path):
+    return np.load(path)
+
+
+def load_3d_flow(path):
+    return np.load(path)
+
+
+def rMAE(prediction, target, dims=(1, 2)):
+ return paddle.abs(x=prediction - target).mean(axis=dims) / paddle.abs(
+ x=target
+ ).mean(axis=dims)
+
+
+class Normalizer_ts(object):
+ def __init__(self, params=[], method="-11", dim=None):
+ self.params = params
+ self.method = method
+ self.dim = dim
+
+    def fit_normalize(self, data):
+        assert isinstance(data, paddle.Tensor)
+        if len(self.params) == 0:
+            if self.method == "-11" or self.method == "01":
+                if self.dim is None:
+                    self.params = paddle.max(x=data), paddle.min(x=data)
+                else:
+                    self.params = (
+                        paddle.max(x=data, axis=self.dim, keepdim=True),
+                        paddle.min(x=data, axis=self.dim, keepdim=True),
+                    )
+ elif self.method == "ms":
+ if self.dim is None:
+ self.params = paddle.mean(x=data, axis=self.dim), paddle.std(
+ x=data, axis=self.dim
+ )
+ else:
+ self.params = paddle.mean(
+ x=data, axis=self.dim, keepdim=True
+ ), paddle.std(x=data, axis=self.dim, keepdim=True)
+ elif self.method == "none":
+ self.params = None
+ return self.fnormalize(data, self.params, self.method)
+
+    def normalize(self, new_data):
+        return self.fnormalize(new_data, self.params, self.method)
+
+    def denormalize(self, new_data_norm):
+        return self.fdenormalize(new_data_norm, self.params, self.method)
+
+ def get_params(self):
+ if self.method == "ms":
+ print("returning mean and std")
+ elif self.method == "01":
+ print("returning max and min")
+ elif self.method == "-11":
+ print("returning max and min")
+ elif self.method == "none":
+ print("do nothing")
+ return self.params
+
+    @staticmethod
+    def fnormalize(data, params, method):
+        if method == "-11":
+            return (data - params[1]) / (params[0] - params[1]) * 2 - 1
+        elif method == "01":
+            return (data - params[1]) / (params[0] - params[1])
+        elif method == "ms":
+            return (data - params[0]) / params[1]
+        elif method == "none":
+            return data
+
+    @staticmethod
+    def fdenormalize(data_norm, params, method):
+        if method == "-11":
+            return (data_norm + 1) / 2 * (params[0] - params[1]) + params[1]
+        elif method == "01":
+            return data_norm * (params[0] - params[1]) + params[1]
+        elif method == "ms":
+            return data_norm * params[1] + params[0]
+        elif method == "none":
+            return data_norm
+
+
+class basic_set(paddle.io.Dataset):
+    def __init__(self, fois, coord, global_indices=None, extra_siren_in=None) -> None:
+        super().__init__()
+        self.fois = fois.numpy()
+        self.total_samples = tuple(fois.shape)[0]
+        self.coords = coord.numpy()
+        # keep the global sample indices used to address the latent container
+        self.global_indices = (
+            global_indices
+            if global_indices is not None
+            else np.arange(self.total_samples)
+        )
+        # optional extra SIREN inputs (only used when provided)
+        if extra_siren_in is not None:
+            self.extra_in = extra_siren_in
+
+    def __len__(self):
+        return self.total_samples
+
+    def __getitem__(self, idx):
+        # map the local dataset index to its global index
+        global_idx = self.global_indices[idx]
+        if hasattr(self, "extra_in"):
+            extra_id = idx % tuple(self.fois.shape)[1]
+            idb = idx // tuple(self.fois.shape)[1]
+            return (
+                (self.coords, self.extra_in[extra_id]),
+                self.fois[idb, extra_id],
+                global_idx,
+            )
+        else:
+            return self.coords, self.fois[idx], global_idx
+
+
+# build data
+def getdata(cfg):
+ ###### read data - fois ######
+ if cfg.Data.load_data_fn == "load_3d_flow":
+ fois = load_3d_flow(cfg.Data.data_path)
+ elif cfg.Data.load_data_fn == "load_elbow_flow":
+ fois = load_elbow_flow(cfg.Data.data_path)
+ elif cfg.Data.load_data_fn == "load_channel_flow":
+ fois = load_channel_flow(cfg.Data.data_path)
+ elif cfg.Data.load_data_fn == "load_periodic_hill_flow":
+ fois = load_periodic_hill_flow(cfg.Data.data_path)
+ else:
+ fois = np.load(cfg.Data.data_path)
+
+    # spatial shape and spatial axes (everything but the batch and channel axes)
+    spatio_shape = fois.shape[1:-1]
+    spatio_axis = list(
+        range(fois.ndim if isinstance(fois, np.ndarray) else fois.dim())
+    )[1:-1]
+
+ ###### read data - coordinate ######
+ if cfg.Data.coor_path is None:
+ coord = [np.linspace(0, 1, i) for i in spatio_shape]
+ coord = np.stack(np.meshgrid(*coord, indexing="ij"), axis=-1)
+ else:
+ coord = np.load(cfg.Data.coor_path)
+ coord = coord.astype("float32")
+ fois = fois.astype("float32")
+
+ ###### convert to tensor ######
+ fois = (
+ paddle.to_tensor(fois)
+ if not isinstance(fois, paddle.Tensor)
+ else fois
+ )
+ coord = paddle.to_tensor(coord) if not isinstance(coord, paddle.Tensor) else coord
+ N_samples = fois.shape[0]
+
+ ###### normalizer ######
+ in_normalizer = Normalizer_ts(**cfg.Data.normalizer)
+ in_normalizer.fit_normalize(
+ coord if cfg.Latent.lumped else coord.flatten(0, cfg.Latent.dims - 1)
+ )
+ out_normalizer = Normalizer_ts(**cfg.Data.normalizer)
+ out_normalizer.fit_normalize(
+ fois if cfg.Latent.lumped else fois.flatten(0, cfg.Latent.dims)
+ )
+    normed_coords = in_normalizer.normalize(coord)  # the training set doubles as the test set
+ normed_fois = out_normalizer.normalize(fois)
+
+ return normed_coords, normed_fois, N_samples, spatio_axis, out_normalizer
+
+
+def signal_train(cfg, normed_coords, train_normed_fois, test_normed_fois, spatio_axis, out_normalizer, train_indices, test_indices):
+ cnf_model = SIRENAutodecoder_film(**cfg.CONFILD)
+ latents_model = LatentContainer(**cfg.Latent)
+
+    # build train/test datasets, passing the global indices through
+    train_dataset = basic_set(train_normed_fois, normed_coords, train_indices)
+    test_dataset = basic_set(test_normed_fois, normed_coords, test_indices)
+
+ criterion = paddle.nn.MSELoss()
+
+ # set loader
+ train_loader = DataLoader(
+ dataset=train_dataset, batch_size=cfg.TRAIN.batch_size, shuffle=True
+ )
+ test_loader = DataLoader(
+ dataset=test_dataset, batch_size=cfg.TRAIN.test_batch_size, shuffle=False
+ )
+ # set optimizer
+ cnf_optimizer = ppsci.optimizer.Adam(cfg.TRAIN.lr.cnf, weight_decay=0.0)(cnf_model)
+ latents_optimizer = ppsci.optimizer.Adam(cfg.TRAIN.lr.latents, weight_decay=0.0)(
+ latents_model
+ )
+ losses = []
+
+ for i in range(cfg.TRAIN.epochs):
+ cnf_model.train()
+ latents_model.train()
+ if i != 0:
+ cnf_optimizer.step()
+ cnf_optimizer.clear_grad(set_to_zero=False)
+ train_loss = []
+ for batch_coords, batch_fois, idx in train_loader:
+ idx = {"latent_x": idx}
+ batch_latent = latents_model(idx)
+ if isinstance(batch_coords, list):
+ batch_coords = [i for i in batch_coords]
+ data = {
+ "confild_x": batch_coords,
+ "latent_z": batch_latent["latent_z"],
+ }
+ batch_output = cnf_model(data)
+ loss = criterion(batch_output["confild_output"], batch_fois)
+ latents_optimizer.clear_grad(set_to_zero=False)
+ loss.backward()
+ latents_optimizer.step()
+ train_loss.append(loss)
+ epoch_loss = paddle.stack(x=train_loss).mean().item()
+ losses.append(epoch_loss)
+ print("epoch {}, train loss {}".format(i + 1, epoch_loss))
+ if i % 100 == 0:
+ test_error = []
+ cnf_model.eval()
+ latents_model.eval()
+ with paddle.no_grad():
+ for test_coords, test_fois, idx in test_loader:
+ if isinstance(test_coords, list):
+ test_coords = [i for i in test_coords]
+ prediction = out_normalizer.denormalize(
+ cnf_model(
+ {
+ "confild_x": test_coords,
+ "latent_z": latents_model({"latent_x": idx})[
+ "latent_z"
+ ],
+ }
+ )["confild_output"]
+ )
+ target = out_normalizer.denormalize(test_fois)
+ error = rMAE(prediction=prediction, target=target, dims=spatio_axis)
+ test_error.append(error)
+ test_error = paddle.concat(x=test_error).mean(axis=0)
+ print("test MAE: ", test_error)
+ if i % 100 == 0:
+ paddle.save(cnf_model.state_dict(), f"cnf_model_{i}.pdparams")
+ paddle.save(latents_model.state_dict(), f"latents_model_{i}.pdparams")
+    # plot the training-loss curve
+    plt.figure(figsize=(10, 6))
+    plt.plot(range(cfg.TRAIN.epochs), losses, label="Training Loss")
+    plt.title("Training Loss over Epochs")
+    plt.xlabel("Epochs")
+    plt.xticks(rotation=45)
+    plt.ylabel("Loss")
+    plt.legend()
+    plt.grid(True)
+    plt.savefig("case.png")  # save as PNG
+    plt.show()
+
+
+def mutil_train(cfg, normed_coords, train_normed_fois, test_normed_fois, spatio_axis, out_normalizer, train_indices, test_indices):
+ fleet.init(is_collective=True)
+ cnf_model = SIRENAutodecoder_film(**cfg.CONFILD)
+ cnf_model = fleet.distributed_model(cnf_model)
+ latents_model = LatentContainer(**cfg.Latent)
+ latents_model = fleet.distributed_model(latents_model)
+
+ # set optimizer
+ cnf_optimizer = ppsci.optimizer.Adam(cfg.TRAIN.lr.cnf, weight_decay=0.0)(cnf_model)
+ cnf_optimizer = fleet.distributed_optimizer(cnf_optimizer)
+ latents_optimizer = ppsci.optimizer.Adam(cfg.TRAIN.lr.latents, weight_decay=0.0)(
+ latents_model
+ )
+ latents_optimizer = fleet.distributed_optimizer(latents_optimizer)
+
+    # build train/test datasets, passing the global indices through
+    train_dataset = basic_set(train_normed_fois, normed_coords, train_indices)
+    test_dataset = basic_set(test_normed_fois, normed_coords, test_indices)
+
+ train_sampler = DistributedBatchSampler(
+ train_dataset, cfg.TRAIN.batch_size, shuffle=True, drop_last=True
+ )
+ train_loader = DataLoader(
+ train_dataset,
+ batch_sampler=train_sampler,
+ num_workers=cfg.TRAIN.mutil_GPU,
+ use_shared_memory=False,
+ )
+ test_sampler = DistributedBatchSampler(
+ test_dataset, cfg.TRAIN.test_batch_size, drop_last=True
+ )
+ test_loader = DataLoader(
+ test_dataset,
+ batch_sampler=test_sampler,
+ num_workers=cfg.TRAIN.mutil_GPU,
+ use_shared_memory=False,
+ )
+
+ criterion = paddle.nn.MSELoss()
+ losses = []
+
+ for i in range(cfg.TRAIN.epochs):
+ cnf_model.train()
+ latents_model.train()
+ if i != 0:
+ cnf_optimizer.step()
+ cnf_optimizer.clear_grad(set_to_zero=False)
+ train_loss = []
+ for batch_coords, batch_fois, idx in train_loader:
+ idx = {"latent_x": idx}
+ batch_latent = latents_model(idx)
+ if isinstance(batch_coords, list):
+ batch_coords = [i for i in batch_coords]
+ data = {
+ "confild_x": batch_coords,
+ "latent_z": batch_latent["latent_z"],
+ }
+ batch_output = cnf_model(data)
+ loss = criterion(batch_output["confild_output"], batch_fois)
+ latents_optimizer.clear_grad(set_to_zero=False)
+ loss.backward()
+ latents_optimizer.step()
+ train_loss.append(loss)
+ epoch_loss = paddle.stack(x=train_loss).mean().item()
+ losses.append(epoch_loss)
+ print("epoch {}, train loss {}".format(i + 1, epoch_loss))
+ if i % 100 == 0:
+ test_error = []
+ cnf_model.eval()
+ latents_model.eval()
+ with paddle.no_grad():
+ for test_coords, test_fois, idx in test_loader:
+ if isinstance(test_coords, list):
+ test_coords = [i for i in test_coords]
+ prediction = out_normalizer.denormalize(
+ cnf_model(
+ {
+ "confild_x": test_coords,
+ "latent_z": latents_model({"latent_x": idx})[
+ "latent_z"
+ ],
+ }
+ )["confild_output"]
+ )
+ target = out_normalizer.denormalize(test_fois)
+ error = rMAE(prediction=prediction, target=target, dims=spatio_axis)
+ test_error.append(error)
+ test_error = paddle.concat(x=test_error).mean(axis=0)
+ print("test MAE: ", test_error)
+ if i % 100 == 0:
+ paddle.save(cnf_model.state_dict(), f"cnf_model_{i}.pdparams")
+ paddle.save(latents_model.state_dict(), f"latents_model_{i}.pdparams")
+    # plot the training-loss curve
+    plt.figure(figsize=(10, 6))
+    plt.plot(range(cfg.TRAIN.epochs), losses, label="Training Loss")
+    plt.title("Training Loss over Epochs")
+    plt.xlabel("Epochs")
+    plt.xticks(rotation=45)
+    plt.ylabel("Loss")
+    plt.legend()
+    plt.grid(True)
+    plt.savefig("case.png")  # save as PNG
+    plt.show()
+
+
+def train(cfg):
+    # number of GPUs; decides between single- and multi-GPU training
+    world_size = cfg.TRAIN.mutil_GPU
+    # load the data
+    normed_coords, normed_fois, N_samples, spatio_axis, out_normalizer = getdata(cfg)
+    train_normed_fois = normed_fois
+    test_normed_fois = normed_fois
+    train_indices = list(range(N_samples))
+    test_indices = list(range(N_samples))
+
+    if world_size > 1:
+        import paddle.distributed as dist
+
+        dist.init_parallel_env()
+        mutil_train(cfg, normed_coords, train_normed_fois, test_normed_fois,
+                    spatio_axis, out_normalizer, train_indices, test_indices)
+    else:
+        signal_train(cfg, normed_coords, train_normed_fois, test_normed_fois,
+                     spatio_axis, out_normalizer, train_indices, test_indices)
+
+
+def evaluate(cfg: DictConfig):
+    # set data
+    normed_coords, normed_fois, _, spatio_axis, out_normalizer = getdata(cfg)
+
+ if len(normed_coords.shape) + 1 == len(normed_fois.shape):
+ normed_coords = paddle.tile(
+ normed_coords, [normed_fois.shape[0]] + [1] * len(normed_coords.shape)
+ )
+
+    idx = paddle.to_tensor(np.arange(normed_fois.shape[0]), dtype="int64")
+ # set model
+ confild = SIRENAutodecoder_film(**cfg.CONFILD)
+ latent = LatentContainer(**cfg.Latent)
+ logger.info(
+ "Loading pretrained model from {}".format(
+ cfg.EVAL.confild_pretrained_model_path
+ )
+ )
+ ppsci.utils.save_load.load_pretrain(
+ confild,
+ cfg.EVAL.confild_pretrained_model_path,
+ )
+ logger.info(
+ "Loading pretrained model from {}".format(cfg.EVAL.latent_pretrained_model_path)
+ )
+ ppsci.utils.save_load.load_pretrain(
+ latent,
+ cfg.EVAL.latent_pretrained_model_path,
+ )
+ latent_test_pred = latent({"latent_x": idx})
+ y_test_pred = []
+ for i in range(normed_coords.shape[0]):
+ y_test_pred.append(
+ confild(
+ {
+ "confild_x": normed_coords[i],
+ "latent_z": latent_test_pred["latent_z"][i],
+ }
+ )["confild_output"].numpy()
+ )
+ y_test_pred = paddle.to_tensor(np.array(y_test_pred))
+
+ y_test_pred = out_normalizer.denormalize(y_test_pred)
+ y_test = out_normalizer.denormalize(normed_fois)
+
+ logger.info("Result is {}".format(y_test.numpy()))
+
+
+def inference(cfg):
+    # load the normalized dataset and coordinates
+ normed_coords, normed_fois, _, _, _ = getdata(cfg)
+ if len(normed_coords.shape) + 1 == len(normed_fois.shape):
+ normed_coords = paddle.tile(
+ normed_coords, [normed_fois.shape[0]] + [1] * len(normed_coords.shape)
+ )
+
+ fois_len = normed_fois.shape[0]
+ idxs = np.array([i for i in range(fois_len)])
+ from deploy import python_infer
+
+ latent_predictor = python_infer.GeneralPredictor(cfg.INFER.Latent)
+ input_dict = {"latent_x": idxs}
+ output_dict = latent_predictor.predict(input_dict, cfg.INFER.batch_size)
+
+ cnf_predictor = python_infer.GeneralPredictor(cfg.INFER.Confild)
+ input_dict = {
+ "confild_x": normed_coords.numpy(),
+ "latent_z": list(output_dict.values())[0],
+ }
+ output_dict = cnf_predictor.predict(input_dict, cfg.INFER.batch_size)
+
+ logger.info("Result is {}".format(output_dict["fetch_name_0"]))
+
+
+def export(cfg):
+ # set model
+ cnf_model = SIRENAutodecoder_film(**cfg.CONFILD)
+ latent_model = LatentContainer(**cfg.Latent)
+ # initialize solver
+    latent_solver = ppsci.solver.Solver(
+ latent_model,
+ pretrained_model_path=cfg.INFER.Latent.INFER.pretrained_model_path,
+ )
+ cnf_solver = ppsci.solver.Solver(
+ cnf_model,
+ pretrained_model_path=cfg.INFER.Confild.INFER.pretrained_model_path,
+ )
+ # export model
+ from paddle.static import InputSpec
+
+ input_spec = [
+ {key: InputSpec([None], "int64", name=key) for key in latent_model.input_keys},
+ ]
+ cnf_input_spec = [
+ {
+ cnf_model.input_keys[0]: InputSpec(
+ [None] + list(cfg.INFER.Confild.INFER.coord_shape),
+ "float32",
+ name=cnf_model.input_keys[0],
+ ),
+ cnf_model.input_keys[1]: InputSpec(
+ [None] + list(cfg.INFER.Confild.INFER.latents_shape),
+ "float32",
+ name=cnf_model.input_keys[1],
+ ),
+ }
+ ]
+ cnf_solver.export(cnf_input_spec, cfg.INFER.Confild.INFER.export_path)
+    latent_solver.export(input_spec, cfg.INFER.Latent.INFER.export_path)
+
+
+@hydra.main(version_base=None, config_path="./conf", config_name="confild_case1.yaml")
+def main(cfg: DictConfig):
+ if cfg.mode == "train":
+ train(cfg)
+ elif cfg.mode == "eval":
+ evaluate(cfg)
+ elif cfg.mode == "infer":
+ inference(cfg)
+ elif cfg.mode == "export":
+ export(cfg)
+ else:
+        raise ValueError(
+            f"cfg.mode should be in ['train', 'eval', 'infer', 'export'], but got '{cfg.mode}'"
+        )
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/confild/resample.py b/examples/confild/resample.py
new file mode 100644
index 0000000000..c1cd5b8eee
--- /dev/null
+++ b/examples/confild/resample.py
@@ -0,0 +1,154 @@
+from abc import ABC, abstractmethod
+
+import numpy as np
+import paddle
+import paddle.distributed as dist
+
+
+def create_named_schedule_sampler(name, diffusion):
+ """
+ Create a ScheduleSampler from a library of pre-defined samplers.
+
+ :param name: the name of the sampler.
+ :param diffusion: the diffusion object to sample for.
+ """
+ if name == "uniform":
+ return UniformSampler(diffusion)
+ elif name == "loss-second-moment":
+ return LossSecondMomentResampler(diffusion)
+ else:
+ raise NotImplementedError(f"unknown schedule sampler: {name}")
+
+
+class ScheduleSampler(ABC):
+ """
+ A distribution over timesteps in the diffusion process, intended to reduce
+ variance of the objective.
+
+ By default, samplers perform unbiased importance sampling, in which the
+ objective's mean is unchanged.
+ However, subclasses may override sample() to change how the resampled
+ terms are reweighted, allowing for actual changes in the objective.
+ """
+
+ @abstractmethod
+ def weights(self):
+ """
+ Get a numpy array of weights, one per diffusion step.
+
+ The weights needn't be normalized, but must be positive.
+ """
+
+ def sample(self, batch_size):
+ """
+ Importance-sample timesteps for a batch.
+
+ :param batch_size: the number of timesteps.
+ :return: a tuple (timesteps, weights):
+ - timesteps: a tensor of timestep indices.
+ - weights: a tensor of weights to scale the resulting losses.
+ """
+ w = self.weights()
+ p = w / np.sum(w)
+ indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
+        indices = paddle.to_tensor(indices_np, dtype="int64")
+        weights_np = 1 / (len(p) * p[indices_np])
+        weights = paddle.to_tensor(weights_np, dtype="float32")
+ return indices, weights
+
+
+class UniformSampler(ScheduleSampler):
+ def __init__(self, diffusion):
+ self.diffusion = diffusion
+ self._weights = np.ones([diffusion.num_timesteps])
+
+ def weights(self):
+ return self._weights
+
+
+class LossAwareSampler(ScheduleSampler):
+ def update_with_local_losses(self, local_ts, local_losses):
+ """
+ Update the reweighting using losses from a model.
+
+ Call this method from each rank with a batch of timesteps and the
+ corresponding losses for each of those timesteps.
+ This method will perform synchronization to make sure all of the ranks
+ maintain the exact same reweighting.
+
+ :param local_ts: an integer Tensor of timesteps.
+ :param local_losses: a 1D Tensor of losses.
+ """
+
+        batch_sizes = [
+            paddle.to_tensor([0], dtype="int32", place=local_ts.place)
+            for _ in range(dist.get_world_size())
+        ]
+        dist.all_gather(
+            batch_sizes,
+            paddle.to_tensor([len(local_ts)], dtype="int32", place=local_ts.place),
+        )
+
+ # Pad all_gather batches to be the maximum batch size.
+ batch_sizes = [x.item() for x in batch_sizes]
+ max_bs = max(batch_sizes)
+
+        timestep_batches = [paddle.zeros([max_bs], dtype=local_ts.dtype) for bs in batch_sizes]
+        loss_batches = [paddle.zeros([max_bs], dtype=local_losses.dtype) for bs in batch_sizes]
+ dist.all_gather(timestep_batches, local_ts)
+ dist.all_gather(loss_batches, local_losses)
+ timesteps = [
+ x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]
+ ]
+ losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]]
+ self.update_with_all_losses(timesteps, losses)
+
+ @abstractmethod
+ def update_with_all_losses(self, ts, losses):
+ """
+ Update the reweighting using losses from a model.
+
+ Sub-classes should override this method to update the reweighting
+ using losses from the model.
+
+ This method directly updates the reweighting without synchronizing
+ between workers. It is called by update_with_local_losses from all
+ ranks with identical arguments. Thus, it should have deterministic
+ behavior to maintain state across workers.
+
+ :param ts: a list of int timesteps.
+ :param losses: a list of float losses, one per timestep.
+ """
+
+
+class LossSecondMomentResampler(LossAwareSampler):
+ def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
+ self.diffusion = diffusion
+ self.history_per_term = history_per_term
+ self.uniform_prob = uniform_prob
+ self._loss_history = np.zeros(
+ [diffusion.num_timesteps, history_per_term], dtype=np.float64
+ )
+        self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int64)
+
+ def weights(self):
+ if not self._warmed_up():
+ return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
+ weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1))
+ weights /= np.sum(weights)
+ weights *= 1 - self.uniform_prob
+ weights += self.uniform_prob / len(weights)
+ return weights
+
+ def update_with_all_losses(self, ts, losses):
+ for t, loss in zip(ts, losses):
+ if self._loss_counts[t] == self.history_per_term:
+ # Shift out the oldest loss term.
+ self._loss_history[t, :-1] = self._loss_history[t, 1:]
+ self._loss_history[t, -1] = loss
+ else:
+ self._loss_history[t, self._loss_counts[t]] = loss
+ self._loss_counts[t] += 1
+
+ def _warmed_up(self):
+ return (self._loss_counts == self.history_per_term).all()
diff --git a/examples/confild/un_confild.py b/examples/confild/un_confild.py
new file mode 100644
index 0000000000..7374c85a6e
--- /dev/null
+++ b/examples/confild/un_confild.py
@@ -0,0 +1,1019 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# 导入必要的库
+from abc import ABC, abstractmethod
+import copy
+import enum
+import functools
+import math
+import hydra
+import matplotlib.pyplot as plt
+import numpy as np
+import paddle
+import os
+from omegaconf import DictConfig
+from resample import UniformSampler, LossAwareSampler
+
+from ppsci.arch import UNetModel
+from ppsci.arch import SIRENAutodecoder_film
+from ppsci.arch import SpacedDiffusion
+from ppsci.arch import ModelVarType
+from ppsci.arch import ModelMeanType
+from ppsci.utils import logger
+from ppsci.arch import LossType
+
+
+def mean_flat(tensor):
+ """
+ 计算张量除批次维度外所有维度的平均值
+
+ 参数:
+ tensor: 输入张量
+
+ 返回:
+ 除批次维度外所有维度的平均值
+ """
+ return tensor.mean(axis=list(range(1, len(tensor.shape))))
+
+
+def normal_kl(mean1, logvar1, mean2, logvar2):
+ """
+ 计算两个高斯分布之间的KL散度
+
+ 参数:
+ mean1: 第一个高斯分布的均值
+ logvar1: 第一个高斯分布的对数方差
+ mean2: 第二个高斯分布的均值
+ logvar2: 第二个高斯分布的对数方差
+
+ 返回:
+ 两个高斯分布之间的KL散度
+ """
+ return 0.5 * (
+ -1.0
+ + logvar2
+ - logvar1
+ + paddle.exp(logvar1 - logvar2)
+ + ((mean1 - mean2) ** 2) * paddle.exp(-logvar2)
+ )
+
+
+def _extract_into_tensor(arr, timesteps, broadcast_shape):
+ """
+ 从一维numpy数组中为一批索引提取值
+
+ 参数:
+ arr: 一维numpy数组
+ timesteps: 时间步索引
+ broadcast_shape: 广播形状
+
+ 返回:
+ 提取并广播后的张量
+ """
+ # 修复变量名错误
+ res = paddle.to_tensor(arr)[timesteps].astype(timesteps.dtype)
+ while len(res.shape) < len(broadcast_shape):
+ res = res[..., None]
+ return res.expand(broadcast_shape)
+
+
+# 添加用于存储训练和验证损失的全局变量
+train_losses = [] # 存储训练损失
+valid_losses = [] # 存储验证损失
+
+
+def create_model(
+ image_size,
+ num_channels,
+ num_res_blocks,
+ dims=2,
+ out_channels=1,
+ channel_mult=None,
+ learn_sigma=False,
+ class_cond=False,
+ use_checkpoint=False,
+ attention_resolutions="16",
+ num_heads=1,
+ num_head_channels=-1,
+ num_heads_upsample=-1,
+ use_scale_shift_norm=False,
+ dropout=0,
+ resblock_updown=False,
+ use_fp16=False,
+ use_new_attention_order=False,
+):
+ """
+ 创建UNet模型
+
+ 参数:
+ image_size: 图像尺寸
+ num_channels: 模型通道数
+ num_res_blocks: 每个下采样级别的残差块数
+ dims: 数据维度(1=1D, 2=2D, 3=3D)
+ out_channels: 输出张量的通道数
+ channel_mult: 每个级别的通道乘数
+ learn_sigma: 是否学习方差
+ class_cond: 是否使用类别条件
+ use_checkpoint: 是否启用梯度检查点
+ attention_resolutions: 应用注意力的下采样率
+ num_heads: 注意力头数
+ num_head_channels: 每个注意力头的通道数
+ num_heads_upsample: 上采样块的注意力头数
+ use_scale_shift_norm: 是否使用FiLM-like调节
+ dropout: Dropout概率
+ resblock_updown: 是否使用残差块进行重采样
+ use_fp16: 是否使用float16精度
+ use_new_attention_order: 是否使用优化的注意力模式
+
+ 返回:
+ UNet模型实例
+ """
+ if channel_mult is None:
+ if image_size == 512:
+ channel_mult = (0.5, 1, 1, 2, 2, 4, 4)
+ elif image_size == 256:
+ channel_mult = (1, 1, 2, 2, 4, 4)
+ elif image_size == 128:
+ channel_mult = (1, 1, 2, 3, 4)
+ elif image_size == 64:
+ channel_mult = (1, 2, 3, 4)
+ else:
+ raise ValueError(f"unsupported image size: {image_size}")
+ else:
+ # 修复channel_mult处理逻辑,确保类型正确
+ if isinstance(channel_mult, str):
+ channel_mult = tuple(int(ch_mult) for ch_mult in channel_mult.split(","))
+
+ attention_ds = []
+ for res in attention_resolutions.split(","):
+ attention_ds.append(image_size // int(res))
+
+ return UNetModel(
+ image_size=image_size,
+ in_channels=out_channels,
+ model_channels=num_channels,
+ out_channels=(out_channels if not learn_sigma else 2*out_channels),#(3 if not learn_sigma else 6),
+ num_res_blocks=num_res_blocks,
+ attention_resolutions=tuple(attention_ds),
+ dropout=dropout,
+ channel_mult=channel_mult,
+ num_classes=(1000 if class_cond else None),
+ use_checkpoint=use_checkpoint,
+ use_fp16=use_fp16,
+ num_heads=num_heads,
+ num_head_channels=num_head_channels,
+ num_heads_upsample=num_heads_upsample,
+ use_scale_shift_norm=use_scale_shift_norm,
+ resblock_updown=resblock_updown,
+ use_new_attention_order=use_new_attention_order,
+ dims=dims
+ )
+
+
+# class LossType(enum.Enum):
+# """
+# 损失类型枚举
+# """
+# MSE = enum.auto() # 使用原始MSE损失(学习方差时使用KL)
+# RESCALED_MSE = (
+# enum.auto()
+# ) # 使用原始MSE损失(学习方差时使用RESCALED_KL)
+# KL = enum.auto() # 使用变分下界
+# RESCALED_KL = enum.auto() # 类似KL,但重新缩放以估计完整的VLB
+
+# def is_vb(self):
+# """
+# 判断是否为变分下界损失
+# """
+# return self == LossType.KL or self == LossType.RESCALED_KL
+
+
+def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
+ """
+ 获取命名的beta调度
+
+ 参数:
+ schedule_name: 调度名称("linear"或"cosine")
+ num_diffusion_timesteps: 扩散步骤数
+
+ 返回:
+ beta值数组
+ """
+ if schedule_name == "linear":
+ # Ho等人的线性调度,扩展为适用于任何数量的扩散步骤
+ scale = 1000 / num_diffusion_timesteps
+ beta_start = scale * 0.0001
+ beta_end = scale * 0.02
+ return np.linspace(
+ beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64
+ )
+ elif schedule_name == "cosine":
+ return betas_for_alpha_bar(
+ num_diffusion_timesteps,
+ lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
+ )
+ else:
+ raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
+
+
+def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
+ """
+ 基于alpha_bar创建betas
+
+ 参数:
+ num_diffusion_timesteps: 扩散步骤数
+ alpha_bar: 累积alpha值函数
+ max_beta: beta的最大值
+
+ 返回:
+ beta值数组
+ """
+ betas = []
+ for i in range(num_diffusion_timesteps):
+ t1 = i / num_diffusion_timesteps
+ t2 = (i + 1) / num_diffusion_timesteps
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
+ return np.array(betas)
+
+
+def space_timesteps(num_timesteps, section_counts):
+ """
+ 在基础扩散过程中跳过步骤的时间步空间化
+
+ 参数:
+ num_timesteps: 原始时间步数
+ section_counts: 每个部分的时间步数
+
+ 返回:
+ 保留的时间步集合
+ """
+ if isinstance(section_counts, str):
+ if section_counts.startswith("ddim"):
+ desired_count = int(section_counts[len("ddim") :])
+ for i in range(1, num_timesteps):
+ if len(range(0, num_timesteps, i)) == desired_count:
+ return set(range(0, num_timesteps, i))
+ raise ValueError(
+ f"cannot create exactly {num_timesteps} steps with an integer stride"
+ )
+ section_counts = [int(x) for x in section_counts.split(",")]
+ size_per = num_timesteps // len(section_counts)
+ extra = num_timesteps % len(section_counts)
+ start_idx = 0
+ all_steps = []
+ for i, section_count in enumerate(section_counts):
+ size = size_per + (1 if i < extra else 0)
+ if size < section_count:
+ raise ValueError(
+ f"cannot divide section of {size} steps into {section_count}"
+ )
+ if section_count <= 1:
+ frac_stride = 1
+ else:
+ frac_stride = (size - 1) / (section_count - 1)
+ cur_idx = 0.0
+ taken_steps = []
+ for _ in range(section_count):
+ taken_steps.append(start_idx + round(cur_idx))
+ cur_idx += frac_stride
+ all_steps += taken_steps
+ start_idx += size
+ return set(all_steps)
+
+
+def create_gaussian_diffusion(
+ *,
+ steps=1000,
+ learn_sigma=False,
+ sigma_small=False,
+ noise_schedule="linear",
+ use_kl=False,
+ predict_xstart=False,
+ rescale_timesteps=False,
+ rescale_learned_sigmas=False,
+ timestep_respacing="",
+):
+ """
+ 创建高斯扩散过程
+
+ 参数:
+ steps: 扩散步骤数
+ learn_sigma: 是否学习方差
+ sigma_small: 是否使用小方差
+ noise_schedule: 噪声调度("linear"或"cosine")
+ use_kl: 是否使用KL损失
+ predict_xstart: 是否预测初始x
+ rescale_timesteps: 是否重新缩放时间步
+ rescale_learned_sigmas: 是否重新缩放学习的sigma
+ timestep_respacing: 时间步重新间隔
+
+ 返回:
+ 高斯扩散过程实例
+ """
+ betas = get_named_beta_schedule(noise_schedule, steps)
+ if use_kl:
+ loss_type = LossType.RESCALED_KL
+ elif rescale_learned_sigmas:
+ loss_type = LossType.RESCALED_MSE
+ else:
+ loss_type = LossType.MSE
+ if not timestep_respacing:
+ timestep_respacing = [steps]
+ return SpacedDiffusion(
+ use_timesteps=space_timesteps(steps, timestep_respacing),
+ betas=betas,
+ model_mean_type=(
+ ModelMeanType.EPSILON if not predict_xstart else ModelMeanType.START_X
+ ),
+ model_var_type=(
+ (
+ ModelVarType.FIXED_LARGE
+ if not sigma_small
+ else ModelVarType.FIXED_SMALL
+ )
+ if not learn_sigma
+ else ModelVarType.LEARNED_RANGE
+ ),
+ loss_type=loss_type,
+ rescale_timesteps=rescale_timesteps,
+ )
+
+
+def load_elbow_flow(path):
+ """
+ 加载肘管流数据
+
+ 参数:
+ path: 数据文件路径
+
+ 返回:
+ 肘管流数据(从索引1开始)
+ """
+ return np.load(f"{path}")[1:]
+
+
+def load_channel_flow(
+ path,
+ t_start=0,
+ t_end=1200,
+ t_every=1,
+):
+ """
+ 加载通道流数据
+
+ 参数:
+ path: 数据文件路径
+ t_start: 起始时间步
+ t_end: 结束时间步
+ t_every: 采样间隔
+
+ 返回:
+ 通道流数据
+ """
+ return np.load(f"{path}")[t_start:t_end:t_every]
+
+
+def load_periodic_hill_flow(path):
+ """
+ 加载周期性山丘流数据
+
+ 参数:
+ path: 数据文件路径
+
+ 返回:
+ 周期性山丘流数据
+ """
+ data = np.load(f"{path}")
+ return data
+
+
+def load_3d_flow(path):
+ """
+ 加载3D流数据
+
+ 参数:
+ path: 数据文件路径
+
+ 返回:
+ 3D流数据
+ """
+ data = np.load(f"{path}")
+ return data
+
+
+class Normalizer_ts(object):
+ """
+ 时间序列归一化器
+ """
+ def __init__(self, params=[], method="-11", dim=None):
+ """
+ 初始化归一化器
+
+ 参数:
+ params: 归一化参数
+ method: 归一化方法("-11", "01", "ms", "none")
+ dim: 归一化维度
+ """
+ self.params = params
+ self.method = method
+ self.dim = dim
+
+ def fit_normalize(self, data):
+ """
+ 拟合并归一化数据
+
+ 参数:
+ data: 输入数据
+
+ 返回:
+ 归一化后的数据
+ """
+ assert type(data) == paddle.Tensor
+ if len(self.params) == 0:
+ if self.method == "-11" or self.method == "01":
+ if self.dim is None:
+ self.params = paddle.max(x=data), paddle.min(x=data)
+ else:
+ self.params = (
+ paddle.max(keepdim=True, x=data, axis=self.dim),
+ paddle.argmax(keepdim=True, x=data, axis=self.dim),
+ )[0], (
+ paddle.min(keepdim=True, x=data, axis=self.dim),
+ paddle.argmin(keepdim=True, x=data, axis=self.dim),
+ )[
+ 0
+ ]
+ elif self.method == "ms":
+ if self.dim is None:
+ self.params = paddle.mean(x=data, axis=self.dim), paddle.std(
+ x=data, axis=self.dim
+ )
+ else:
+ self.params = paddle.mean(
+ x=data, axis=self.dim, keepdim=True
+ ), paddle.std(x=data, axis=self.dim, keepdim=True)
+ elif self.method == "none":
+ self.params = None
+ return self.fnormalize(data, self.params, self.method)
+
+ def normalize(self, new_data):
+ """
+ 归一化新数据
+
+ 参数:
+ new_data: 新数据
+
+ 返回:
+ 归一化后的数据
+ """
+ if not new_data.place == self.params[0].place:
+ self.params = self.params[0], self.params[1]
+ return self.fnormalize(new_data, self.params, self.method)
+
+ def denormalize(self, new_data_norm):
+ """
+ 反归一化数据
+
+ 参数:
+ new_data_norm: 归一化后的数据
+
+ 返回:
+ 反归一化后的数据
+ """
+ if not new_data_norm.place == self.params[0].place:
+ self.params = self.params[0], self.params[1]
+ return self.fdenormalize(new_data_norm, self.params, self.method)
+
+ def get_params(self):
+ """
+ 获取归一化参数
+ """
+ if self.method == "ms":
+ print("returning mean and std")
+ elif self.method == "01":
+ print("returning max and min")
+ elif self.method == "-11":
+ print("returning max and min")
+ elif self.method == "none":
+ print("do nothing")
+ return self.params
+
+ @staticmethod
+ def fnormalize(data, params, method):
+ """
+ 执行归一化
+
+ 参数:
+ data: 输入数据
+ params: 归一化参数
+ method: 归一化方法
+
+ 返回:
+ 归一化后的数据
+ """
+ if method == "-11":
+ return (data - params[1]) / (
+ params[0] - params[1]
+ ) * 2 - 1
+ elif method == "01":
+ return (data - params[1]) / (
+ params[0] - params[1]
+ )
+ elif method == "ms":
+ return (data - params[0]) / params[1]
+ elif method == "none":
+ return data
+
+ @staticmethod
+ def fdenormalize(data_norm, params, method):
+ """
+ 执行反归一化
+
+ 参数:
+ data_norm: 归一化后的数据
+ params: 归一化参数
+ method: 归一化方法
+
+ 返回:
+ 反归一化后的数据
+ """
+ if method == "-11":
+ return (data_norm + 1) / 2 * (params[0] - params[1]) + params[1]
+ elif method == "01":
+ return data_norm * (
+ params[0] - params[1]
+ ) + params[1]
+ elif method == "ms":
+ return data_norm * params[1] + params[0]
+ elif method == "none":
+ return data_norm
+
+
+def create_slim(cfg):
+ """
+ 创建SLIM模型
+
+ 参数:
+ cfg: 配置对象
+
+ 返回:
+ CNF模型、输入归一化器、输出归一化器和坐标
+ """
+ ###### read data - fois ######
+ if cfg.CNF.load_data_fn == "load_3d_flow":
+ fois = load_3d_flow(cfg.CNF.data_path)
+ elif cfg.CNF.load_data_fn == "load_elbow_flow":
+ fois = load_elbow_flow(cfg.CNF.data_path)
+ elif cfg.CNF.load_data_fn == "load_channel_flow":
+ fois = load_channel_flow(cfg.CNF.data_path)
+ elif cfg.CNF.load_data_fn == "load_periodic_hill_flow":
+ fois = load_periodic_hill_flow(cfg.CNF.data_path)
+ else:
+ fois = np.load(cfg.CNF.data_path)
+
+ # 计算空间形状和轴
+ spatio_shape = fois.shape[1:-1]
+
+ ###### read data - coordinate ######
+ if cfg.CNF.coor_path is None:
+ coord = [np.linspace(0, 1, i) for i in spatio_shape]
+ coord = np.stack(np.meshgrid(*coord, indexing="ij"), axis=-1)
+ else:
+ coord = np.load(cfg.CNF.coor_path)
+ coord = coord.astype("float32")
+ fois = fois.astype("float32")
+
+ ###### convert to tensor ######
+ fois = (
+ paddle.to_tensor(fois)
+ if not isinstance(fois, paddle.Tensor)
+ else fois
+ )
+ coord = paddle.to_tensor(coord) if not isinstance(coord, paddle.Tensor) else coord
+ N_samples = fois.shape[0]
+
+ ###### normalizer ######
+ in_normalizer = Normalizer_ts(**cfg.CNF.normalizer)
+ out_normalizer = Normalizer_ts(**cfg.CNF.normalizer)
+ # 使用最新的模型参数
+ norm_params = paddle.load(cfg.CNF.normalizer_params_path)
+ in_normalizer.params = norm_params["x_normalizer_params"]
+ out_normalizer.params = norm_params["y_normalizer_params"]
+
+ cnf_model = SIRENAutodecoder_film(**cfg.CNF.CONFILD)
+
+ return cnf_model, in_normalizer, out_normalizer, coord
+
+
+def dl_iter(dl):
+ """
+ 数据加载器迭代器
+
+ 参数:
+ dl: 数据加载器
+
+ 返回:
+ 无限迭代数据加载器
+ """
+ while True:
+ yield from dl
+
+
+def train(cfg):
+ """
+ 训练函数
+
+ 参数:
+ cfg: 配置对象
+ """
+ # create parameters
+ batch_size = cfg.TRAIN.batch_size
+ test_batch_size = cfg.TRAIN.test_batch_size
+ ema_rate = cfg.TRAIN.ema_rate
+ ema_rate = (
+ [ema_rate]
+ if isinstance(ema_rate, float)
+ else [float(x) for x in ema_rate.split(",")]
+ )
+
+ lr_anneal_steps = cfg.TRAIN.lr_anneal_steps
+ final_lr = cfg.TRAIN.final_lr
+ step = 0
+ resume_step = 0
+ microbatch = cfg.TRAIN.microbatch if cfg.TRAIN.microbatch > 0 else batch_size
+
+ ## Data Preprocessing
+ train_data = np.load(cfg.DATA.train_data)
+ valid_data = np.load(cfg.DATA.valid_data)
+ max_val, min_val = np.max(train_data, keepdims=True), np.min(train_data, keepdims=True)
+ norm_train_data = -1 + (train_data - min_val)*2. / (max_val - min_val)
+ norm_valid_data = -1 + (valid_data - min_val)*2. / (max_val - min_val)
+
+ norm_train_data = paddle.to_tensor(norm_train_data[:, None, ...])
+ norm_valid_data = paddle.to_tensor(norm_valid_data[:, None, ...])
+
+ dl_train = dl_iter(paddle.io.DataLoader(paddle.io.TensorDataset(norm_train_data), batch_size=batch_size, shuffle=True))
+ dl_valid = dl_iter(paddle.io.DataLoader(paddle.io.TensorDataset(norm_valid_data), batch_size=test_batch_size, shuffle=True))
+
+ unet_model = create_model(image_size=cfg.UNET.image_size,
+ num_channels= cfg.UNET.num_channels,
+ num_res_blocks= cfg.UNET.num_res_blocks,
+ num_heads=cfg.UNET.num_heads,
+ num_head_channels=cfg.UNET.num_head_channels,
+ attention_resolutions=cfg.UNET.attention_resolutions,
+ channel_mult=cfg.UNET.channel_mult
+ )
+ diff_model = create_gaussian_diffusion(steps=cfg.Diff.steps,
+ noise_schedule=cfg.Diff.noise_schedule
+ )
+
+ # 初始化AdamW优化器
+ opt = paddle.optimizer.AdamW(
+ parameters=unet_model.parameters(), learning_rate=cfg.TRAIN.lr, weight_decay=cfg.TRAIN.weight_decay
+ )
+
+ schedule_sampler = UniformSampler(diff_model)
+
+ # 初始化EMA参数
+ ema_params = []
+ for _ in range(len(ema_rate)):
+ ema_param_dict = {}
+ for name, param in unet_model.named_parameters():
+ ema_param_dict[name] = copy.deepcopy(param.detach())
+ ema_params.append(ema_param_dict)
+
+ # 清空损失记录
+ global train_losses, valid_losses
+ train_losses.clear()
+ valid_losses.clear()
+
+ while (
+ not lr_anneal_steps
+ or step + resume_step < lr_anneal_steps
+ ):
+ cond = {}
+ # 获取下一个训练批次和验证批次的数据
+ train_batch = next(dl_train)
+ valid_batch = next(dl_valid)
+ # 前向传播
+ unet_model.train()
+ # 清零梯度(使用clear_grad更高效)
+ unet_model.clear_grad()
+
+ for i in range(0, len(train_batch), microbatch):
+ # 获取当前微批次数据
+ micro = train_batch[i : i + microbatch]
+ micro_cond = {
+ k: v[i : i + microbatch]
+ for k, v in cond.items()
+ }
+
+ # 从调度采样器中采样时间步
+ t, weights = schedule_sampler.sample(len(micro))
+
+ # 创建部分应用的损失计算函数
+ # 注意:micro已经是tensor(来自DataLoader),无需再次转换
+ compute_losses = functools.partial(
+ diff_model.training_losses,
+ unet_model,
+ micro,
+ t,
+ model_kwargs=micro_cond
+ )
+
+ # 计算损失
+ losses = compute_losses()
+
+ # 如果使用损失感知采样器,则更新本地损失
+ if isinstance(schedule_sampler, LossAwareSampler):
+ schedule_sampler.update_with_local_losses(
+ t, losses["loss"].detach()
+ )
+
+ # 计算加权平均损失
+ loss = (losses["loss"] * weights).mean()
+
+ # 记录损失字典(排除非张量类型的键)
+ log_loss_dict(
+ diff_model, t, {k: v * weights for k, v in losses.items() if isinstance(v, paddle.Tensor)}, is_valid=False
+ )
+
+ # 反向传播
+ # unet_model.backward(loss)
+ loss.backward()
+
+ # 不计算梯度,节省内存,设置模型为评估模式
+ unet_model.eval()
+ with paddle.no_grad():
+ # 聚合所有微批次的验证损失
+ all_valid_losses = []
+
+ # 同样分解成微批次处理
+ for i in range(0, len(valid_batch), microbatch):
+ # 获取当前微批次数据
+ micro = valid_batch[i : i + microbatch]
+ micro_cond = {
+ k: v[i : i + microbatch]
+ for k, v in cond.items()
+ }
+
+ # 采样时间步
+ t, weights = schedule_sampler.sample(len(micro))
+
+ # 创建部分应用的损失计算函数
+ # 注意:micro已经是tensor(来自DataLoader),无需再次转换
+ compute_losses = functools.partial(
+ diff_model.training_losses,
+ unet_model,
+ micro,
+ t,
+ model_kwargs=micro_cond,
+ valid=True
+ )
+
+ # 计算验证损失
+ losses = compute_losses()
+
+ # 记录验证损失(排除非张量类型的键,如布尔标记等)
+ valid_loss_dict = {k: v * weights for k, v in losses.items() if isinstance(v, paddle.Tensor)}
+ # 验证时不添加到列表,而是在外部聚合后统一添加
+ log_loss_dict(
+ diff_model, t, valid_loss_dict, is_valid=True, add_to_list=False
+ )
+
+ # 收集损失用于聚合
+ if "loss" in valid_loss_dict:
+ all_valid_losses.append(valid_loss_dict["loss"].mean().item())
+
+ # 聚合整个验证批次的平均损失并添加一次
+ if len(all_valid_losses) > 0:
+ avg_valid_loss = sum(all_valid_losses) / len(all_valid_losses)
+ valid_losses.append(avg_valid_loss)
+
+ # 验证结束后切换回训练模式
+ unet_model.train()
+
+ grad_norm, param_norm = _compute_norms(unet_model)
+ opt.step()
+ # took_step = unet_model.optimize(opt)
+ # 更新ema参数
+ _update_ema(ema_rate, ema_params, unet_model)
+ # 更新学习率
+ _anneal_lr(lr_anneal_steps, step, resume_step, opt, final_lr, cfg.TRAIN.lr)
+
+ step += 1
+
+ # 每100步打印一次训练和验证损失
+ if step % 100 == 0:
+ if len(train_losses) > 0 and len(valid_losses) > 0:
+ print(f"Step {step}: Train Loss: {train_losses[-1]:.6f}, Valid Loss: {valid_losses[-1]:.6f}")
+
+ # 保存模型
+ paddle.save(unet_model.state_dict(), "unet.pdparams")
+
+ # 绘制训练和验证损失曲线
+ plot_losses()
+
+
+def plot_losses():
+ """
+ 绘制训练和验证损失曲线
+ """
+ if len(train_losses) == 0 or len(valid_losses) == 0:
+ print("没有足够的数据来绘制损失曲线")
+ return
+
+ plt.figure(figsize=(10, 6))
+ plt.plot(train_losses, label='Training Loss', alpha=0.8)
+ plt.plot(valid_losses, label='Validation Loss', alpha=0.8)
+ plt.xlabel('Training Steps')
+ plt.ylabel('Loss')
+ plt.title('Training and Validation Loss')
+ plt.legend()
+ plt.grid(True)
+ plt.tight_layout()
+
+ # 保存图像
+ plt.savefig('loss_curve.png', dpi=300, bbox_inches='tight')
+ print("损失曲线已保存为 loss_curve.png")
+
+ # 显示图像
+ plt.show()
+
+
+def _compute_norms(model, grad_scale=1.0):
+ """
+ 计算模型参数和梯度的范数
+
+ 参数:
+ model: 模型
+ grad_scale: 梯度缩放因子
+
+ 返回:
+ 梯度范数和参数范数
+ """
+ grad_norm = 0.0
+ param_norm = 0.0
+ for p in model.parameters():
+ with paddle.no_grad():
+ param_norm += paddle.norm(p, p=2, dtype=paddle.float32).item() ** 2
+ if p.grad is not None:
+ grad_norm += paddle.norm(p.grad, p=2, dtype=paddle.float32).item() ** 2
+ return np.sqrt(grad_norm) / grad_scale, np.sqrt(param_norm)
+
+
+def _update_ema(ema_rate, ema_params, source_model):
+ """
+ 更新EMA(指数移动平均)参数
+ EMA有助于提高生成质量,减少模型权重噪声
+
+ 参数:
+ ema_rate: EMA衰减率列表
+ ema_params: EMA参数字典列表
+ source_model: 源模型
+ """
+ for rate, target_params_dict in zip(ema_rate, ema_params):
+ for name, target_param in target_params_dict.items():
+ source_param = dict(source_model.named_parameters())[name]
+ updated = target_param.detach() * rate + source_param.detach() * (1 - rate)
+ target_param.set_value(updated)
+
+
+def _anneal_lr(lr_anneal_steps, step, resume_step, opt, final_lr, lr):
+ """
+ 学习率退火调整
+ 根据训练进度线性降低学习率
+
+ 参数:
+ lr_anneal_steps: 学习率退火步数
+ step: 当前步数
+ resume_step: 恢复步数
+ opt: 优化器
+ final_lr: 最终学习率
+ lr: 初始学习率
+ """
+ if not lr_anneal_steps:
+ return
+ frac_done = (step + resume_step) / lr_anneal_steps
+ new_lr = final_lr * (frac_done) + lr * (1 - frac_done)
+ opt.set_lr(new_lr)
+
+
+def log_loss_dict(diffusion, ts, losses, is_valid=False, add_to_list=True):
+ """
+ 记录损失字典的日志
+
+ 参数:
+ diffusion: 扩散模型对象
+ ts: 时间步张量
+ losses: 损失字典
+ is_valid: 是否为验证损失
+ add_to_list: 是否将损失添加到全局列表中(用于验证时聚合控制)
+ """
+ for key, values in losses.items():
+ # 使用logger.info替代logger.logkv_mean记录平均损失值
+ logger.info(f"{key}: {values.mean().item():.6f}")
+ # 记录分位数(特别是四个四分位数)
+ for sub_t, sub_loss in zip(ts.cpu().numpy(), values.detach().cpu().numpy()):
+ quartile = int(4 * sub_t / diffusion.num_timesteps)
+ logger.info(f"{key}_q{quartile}: {sub_loss:.6f}")
+
+ # 记录训练和验证损失到全局列表
+ if key == "loss" and add_to_list:
+ if is_valid:
+ valid_losses.append(values.mean().item())
+ else:
+ train_losses.append(values.mean().item())
+
+
+def evaluate(cfg):
+ """
+ 评估函数
+
+ 参数:
+ cfg: 配置对象
+ """
+ ## Create model and diffusion
+ unet_model = create_model(image_size=cfg.UNET.image_size,
+ num_channels=cfg.UNET.num_channels,
+ num_res_blocks=cfg.UNET.num_res_blocks,
+ num_heads=cfg.UNET.num_heads,
+ num_head_channels=cfg.UNET.num_head_channels,
+ attention_resolutions=cfg.UNET.attention_resolutions
+ )
+
+ unet_model.set_state_dict(paddle.load(cfg.UNET.ema_path))
+
+ diff_model = create_gaussian_diffusion(steps=cfg.Diff.steps,
+ noise_schedule=cfg.Diff.noise_schedule
+ )
+
+ sample_fn = diff_model.p_sample_loop
+ gen_latents = sample_fn(unet_model, (cfg.EVAL.test_batch_size, 1, cfg.EVAL.time_length, cfg.EVAL.latent_length))[:, 0]
+
+ max_val, min_val = cfg.DATA.max_val, cfg.DATA.min_val#np.load(cfg.DATA.max_val), np.load(cfg.DATA.min_val)
+ max_val, min_val = paddle.to_tensor(max_val), paddle.to_tensor(min_val)
+ gen_latents = (gen_latents + 1)*(max_val - min_val)/2. + min_val
+
+ # 获取模型
+ nf, in_normalizer, out_normalizer, coord = create_slim(cfg)
+ nf.set_state_dict(paddle.load(cfg.CNF.model_path))
+ coord = in_normalizer.normalize(coord)
+
+ batch_size = 1 # if you are limited by your GPU Memory, please change the batch_size variable accordingly
+ n_samples = gen_latents.shape[0]
+ gen_fields = []
+
+ for sample_index in range(n_samples):
+ for i in range(gen_latents.shape[1]//batch_size):
+ new_latents = gen_latents[sample_index, i*batch_size:(i+1)*batch_size]
+ # coord = in_normalizer.normalize(coord)
+ if len(coord.shape) > 2:
+ new_latents = new_latents[:, None, None]
+ else:
+ new_latents = new_latents[:, None]
+ input_data = {
+ "confild_x": coord,
+ "latent_z": new_latents
+ }
+ out = nf(input_data)["confild_output"]
+ out = out_normalizer.denormalize(out)
+ gen_fields.append(out.detach().cpu().numpy())
+
+ gen_fields = np.concatenate(gen_fields)
+
+ np.save(cfg.save_path, gen_fields)
+
+
+@hydra.main(version_base=None, config_path="./conf", config_name="un_confild_case1.yaml")
+def main(cfg: DictConfig):
+ """
+ 主函数
+
+ 参数:
+ cfg: 配置对象
+ """
+ if cfg.mode == "train":
+ train(cfg)
+ elif cfg.mode == "eval":
+ evaluate(cfg)
+ else:
+ raise ValueError(
+ f"cfg.mode should in ['train', 'eval'], but got '{cfg.mode}'"
+ )
+
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/ppsci/arch/__init__.py b/ppsci/arch/__init__.py
index 78f381b68e..8d498cb310 100644
--- a/ppsci/arch/__init__.py
+++ b/ppsci/arch/__init__.py
@@ -22,6 +22,7 @@
from ppsci.arch.amgnet import AMGNet # isort:skip
from ppsci.arch.base import Arch # isort:skip
from ppsci.arch.cfdgcn import CFDGCN # isort:skip
+from ppsci.arch.confild import LatentContainer, LossType, SIRENAutodecoder_film, SpacedDiffusion, UNetModel, ModelVarType, ModelMeanType # isort:skip
from ppsci.arch.smc_reac import SuzukiMiyauraModel # isort:skip
from ppsci.arch.chip_deeponets import ChipDeepONets # isort:skip
from ppsci.arch.crystalgraphconvnet import CrystalGraphConvNet # isort:skip
@@ -98,11 +99,15 @@
"GraphCastNet",
"HEDeepONets",
"LorenzEmbedding",
+ "LatentContainer",
"LatentNO",
"LatentNO_time",
"LNO",
+ "LossType",
"MLP",
"ModelList",
+ "ModelVarType",
+ "ModelMeanType",
"ModifiedMLP",
"NowcastNet",
"PhyCRNet",
@@ -111,12 +116,15 @@
"PrecipNet",
"RosslerEmbedding",
"SFNONet",
+ "SIRENAutodecoder_film",
+ "SpacedDiffusion",
"SPINN",
"TFNO1dNet",
"TFNO2dNet",
"TFNO3dNet",
"Transformer",
"UNetEx",
+ "UNetModel",
"UNONet",
"USCNN",
"VelocityDiscriminator",
diff --git a/ppsci/arch/confild.py b/ppsci/arch/confild.py
new file mode 100644
index 0000000000..873a73371d
--- /dev/null
+++ b/ppsci/arch/confild.py
@@ -0,0 +1,1844 @@
+import math
+import enum
+from collections import OrderedDict
+from abc import abstractmethod
+import numpy as np
+import paddle
+
+DEFAULT_W0 = 30.0
+
+###################### ConFILD Model #######################
+class Swish(paddle.nn.Layer):
+ """
+ Swish activation function: f(x) = x * sigmoid(x).
+
+ A smooth, non-monotonic activation function that has been shown to work
+ better than ReLU on deeper models across a number of challenging datasets.
+ """
+ def __init__(self):
+ super().__init__()
+ self.Sigmoid = paddle.nn.Sigmoid()
+
+ def forward(self, x):
+ """
+ Apply Swish activation.
+
+ Args:
+ x (paddle.Tensor): Input tensor.
+
+ Returns:
+ paddle.Tensor: Output tensor with same shape as input.
+ """
+ return x * self.Sigmoid(x)
+
+
+class Sine(paddle.nn.Layer):
+ """
+ Sine activation function for SIREN (Sinusoidal Representation Networks).
+
+ Args:
+ w0 (float, optional): Frequency parameter for sine activation. Defaults to DEFAULT_W0 (30.0).
+ """
+ def __init__(self, w0=DEFAULT_W0):
+ self.w0 = w0
+ super().__init__()
+
+ def forward(self, input):
+ """
+ Apply sine activation with frequency modulation.
+
+ Args:
+ input (paddle.Tensor): Input tensor.
+
+ Returns:
+ paddle.Tensor: sin(w0 * input).
+ """
+ return paddle.sin(x=self.w0 * input)
+
+
+def sine_init(m, w0=DEFAULT_W0):
+ """
+ Weight initialization for SIREN hidden layers.
+
+ Initializes weights uniformly in [-√(6/n)/w0, √(6/n)/w0] where n is input dimension.
+ This initialization is critical for maintaining stable signal propagation in SIREN networks.
+
+ Args:
+ m (paddle.nn.Layer): Layer to initialize (must have 'weight' attribute).
+ w0 (float, optional): Frequency parameter. Defaults to DEFAULT_W0.
+ """
+ with paddle.no_grad():
+ if hasattr(m, "weight"):
+ num_input = m.weight.shape[-1]
+ m.weight.uniform_(
+ min=-math.sqrt(6 / num_input) / w0, max=math.sqrt(6 / num_input) / w0
+ )
+
+
+def first_layer_sine_init(m):
+ """
+ Weight initialization for SIREN first layer.
+
+ Initializes weights uniformly in [-1/n, 1/n] where n is input dimension.
+ Different from hidden layers to handle raw coordinate inputs properly.
+
+ Args:
+ m (paddle.nn.Layer): Layer to initialize (must have 'weight' attribute).
+ """
+ with paddle.no_grad():
+ if hasattr(m, "weight"):
+ num_input = m.weight.shape[-1]
+ m.weight.uniform_(min=-1 / num_input, max=1 / num_input)
+
+
+def __check_Linear_weight(m):
+ if isinstance(m, paddle.nn.Linear):
+ if hasattr(m, "weight"):
+ return True
+ return False
+
+
+def init_weights_normal(m):
+ if __check_Linear_weight(m):
+ init_KaimingNormal = paddle.nn.initializer.KaimingNormal(
+ nonlinearity="relu", negative_slope=0.0
+ )
+ init_KaimingNormal(m.weight)
+
+
+def init_weights_selu(m):
+ if __check_Linear_weight(m):
+ num_input = m.weight.shape[-1]
+ init_Normal = paddle.nn.initializer.Normal(std=1 / math.sqrt(num_input))
+ init_Normal(m.weight)
+
+
+def init_weights_elu(m):
+ if __check_Linear_weight(m):
+ num_input = m.weight.shape[-1]
+ init_Normal = paddle.nn.initializer.Normal(
+ std=math.sqrt(1.5505188080679277) / math.sqrt(num_input)
+ )
+ init_Normal(m.weight)
+
+
+def init_weights_xavier(m):
+ if __check_Linear_weight(m):
+ init_XavierNormal = paddle.nn.initializer.XavierNormal()
+ init_XavierNormal(m.weight)
+
+
+NLS_AND_INITS = {
+ "sine": (Sine(), sine_init, first_layer_sine_init),
+ "relu": (paddle.nn.ReLU(), init_weights_normal, None),
+ "sigmoid": (paddle.nn.Sigmoid(), init_weights_xavier, None),
+ "tanh": (paddle.nn.Tanh(), init_weights_xavier, None),
+ "selu": (paddle.nn.SELU(), init_weights_selu, None),
+ "softplus": (paddle.nn.Softplus(), init_weights_normal, None),
+ "elu": (paddle.nn.ELU(), init_weights_elu, None),
+ "swish": (Swish(), init_weights_xavier, None),
+}
+
+
+class BatchLinear(paddle.nn.Linear):
+ """
+ Batch-wise linear transformation layer that supports manual parameter injection.
+
+ This layer extends paddle.nn.Linear to allow passing parameters explicitly,
+ which is useful for meta-learning and hypernetwork applications.
+
+ Args:
+ in_features (int): Size of input features.
+ out_features (int): Size of output features.
+
+ Note:
+ - Weight shape: (out_features, in_features)
+ - Bias shape: (out_features,)
+ """
+
+ __doc__ = paddle.nn.Linear.__doc__
+
+ def forward(self, input, params=None):
+ """
+ Forward pass with optional external parameters.
+
+ Args:
+ input (paddle.Tensor): Input tensor of shape (..., in_features).
+ params (OrderedDict, optional): External parameters dict containing 'weight' and optionally 'bias'.
+ If None, uses internal parameters. Defaults to None.
+
+ Returns:
+ paddle.Tensor: Output tensor of shape (..., out_features).
+ """
+ if params is None:
+ params = OrderedDict(self.named_parameters())
+ bias = params.get("bias", None)
+ weight = params["weight"]
+
+ output = paddle.matmul(x=input, y=weight)
+ if bias is not None:
+ output += bias.unsqueeze(axis=-2)
+ return output
+
+
+class FeatureMapping:
+ """
+ Feature mapping class for Fourier Feature Networks.
+
+ Supports multiple mapping strategies including Gaussian random Fourier features,
+ positional encoding, and radial basis functions (RBF) for improving coordinate-based
+ neural network representations.
+
+ Reference:
+ Tancik et al. "Fourier Features Let Networks Learn High Frequency Functions in Low Dimensional Domains"
+ """
+
+ def __init__(
+ self,
+ in_features,
+ mode="basic",
+ gaussian_mapping_size=256,
+ gaussian_rand_key=0,
+ gaussian_tau=1.0,
+ pe_num_freqs=4,
+ pe_scale=2,
+ pe_init_scale=1,
+ pe_use_nyquist=True,
+ pe_lowest_dim=None,
+ rbf_out_features=None,
+ rbf_range=1.0,
+ rbf_std=0.5,
+ ):
+ """
+ Initialize feature mapping.
+
+ Args:
+ in_features (int): Number of input features.
+ mode (str, optional): Mapping mode. Options: "basic", "gaussian", "positional", "rbf". Defaults to "basic".
+ gaussian_mapping_size (int, optional): Output dimension for Gaussian mapping. Defaults to 256.
+ gaussian_rand_key (int, optional): Random seed for Gaussian mapping. Defaults to 0.
+ gaussian_tau (float, optional): Standard deviation for Gaussian mapping. Defaults to 1.0.
+ pe_num_freqs (int, optional): Number of frequency bands for positional encoding. Defaults to 4.
+ pe_scale (int, optional): Base scale for frequencies in positional encoding. Defaults to 2.
+ pe_init_scale (int, optional): Initial scale multiplier for positional encoding. Defaults to 1.
+ pe_use_nyquist (bool, optional): Use Nyquist frequency to determine num_freqs. Defaults to True.
+ pe_lowest_dim (int, optional): Lowest dimension for Nyquist calculation. Defaults to None.
+ rbf_out_features (int, optional): Number of RBF centers. Defaults to None.
+ rbf_range (float, optional): Range for RBF center initialization. Defaults to 1.0.
+ rbf_std (float, optional): Standard deviation for RBF kernels. Defaults to 0.5.
+ """
+ self.mode = mode
+ if mode == "basic":
+ self.B = np.eye(in_features)
+ elif mode == "gaussian":
+ rng = np.random.default_rng(gaussian_rand_key)
+ self.B = rng.normal(
+ loc=0.0, scale=gaussian_tau, size=(gaussian_mapping_size, in_features)
+ )
+ elif mode == "positional":
+ if pe_use_nyquist == "True" and pe_lowest_dim:
+ pe_num_freqs = self.get_num_frequencies_nyquist(pe_lowest_dim)
+ self.B = pe_init_scale * np.vstack(
+ [(pe_scale**i * np.eye(in_features)) for i in range(pe_num_freqs)]
+ )
+ self.dim = tuple(self.B.shape)[0] * 2
+ elif mode == "rbf":
+ self.centers = paddle.base.framework.EagerParamBase.from_tensor(
+ tensor=paddle.empty(
+ shape=(rbf_out_features, in_features), dtype="float32"
+ )
+ )
+ self.sigmas = paddle.base.framework.EagerParamBase.from_tensor(
+ tensor=paddle.empty(shape=rbf_out_features, dtype="float32")
+ )
+ init_Uniform = paddle.nn.initializer.Uniform(
+ low=-1 * rbf_range, high=rbf_range
+ )
+ init_Uniform(self.centers)
+ init_Constant = paddle.nn.initializer.Constant(value=rbf_std)
+ init_Constant(self.sigmas)
+
+ def __call__(self, input):
+ if self.mode in ["basic", "gaussian", "positional"]:
+ return self.fourier_mapping(input, self.B)
+ elif self.mode == "rbf":
+ return self.rbf_mapping(input)
+
+ def get_num_frequencies_nyquist(self, samples):
+ nyquist_rate = 1 / (2 * (2 * 1 / samples))
+ return int(math.floor(math.log(nyquist_rate, 2)))
+
+ @staticmethod
+ def fourier_mapping(x, B):
+ """
+ Apply Fourier feature mapping: [sin(2πxB^T), cos(2πxB^T)].
+
+ Args:
+ x (paddle.Tensor): Input coordinates of shape (..., in_features).
+ B (np.ndarray): Frequency matrix of shape (mapping_size, in_features).
+
+ Returns:
+ paddle.Tensor: Fourier features of shape (..., 2 * mapping_size).
+ """
+ if B is None:
+ return x
+ else:
+ B = paddle.to_tensor(data=B, dtype="float32", place=x.place)
+ x_proj = 2.0 * np.pi * x @ B.T
+ return paddle.concat(
+ x=[paddle.sin(x=x_proj), paddle.cos(x=x_proj)], axis=-1
+ )
+
+ def rbf_mapping(self, x):
+ size = tuple(x.shape)[:-1] + tuple(self.centers.shape)
+ x = x.unsqueeze(axis=-2).expand(shape=size)
+ distances = (x - self.centers).pow(y=2).sum(axis=-1) * self.sigmas
+ return self.gaussian(distances)
+
+ @staticmethod
+ def gaussian(alpha):
+ phi = paddle.exp(x=-1 * alpha.pow(y=2))
+ return phi
+
+
+class SIRENAutodecoder_film(paddle.nn.Layer):
+ """
+ SIREN (Sinusoidal Representation Networks) with FiLM conditioning for autodecoding.
+
+ This architecture uses sine activations and latent code modulation (FiLM) for
+ implicit neural representations. It takes both coordinate inputs and latent codes,
+ making it suitable for learning multiple shapes/scenes with a single network.
+
+ Reference:
+ Sitzmann et al. "Implicit Neural Representations with Periodic Activation Functions" (NeurIPS 2020)
+
+ Args:
+ input_keys (Tuple[str, ...], optional): Keys to get input tensors from dict. First key for coordinates, second for latents.
+ output_keys (Tuple[str, ...], optional): Keys to save output tensors into dict.
+ in_coord_features (int, optional): Number of input coordinate features (e.g., 2 for 2D, 3 for 3D).
+ in_latent_features (int, optional): Number of latent features for conditioning.
+ out_features (int, optional): Number of output features (e.g., 3 for RGB).
+ num_hidden_layers (int, optional): Number of hidden layers.
+ hidden_features (int, optional): Number of hidden layer features.
+ outermost_linear (bool, optional): Whether to use linear layer at output. Defaults to False.
+ nonlinearity (str, optional): Activation function. Options: "sine", "relu", "tanh", etc. Defaults to "sine".
+ weight_init (Callable, optional): Custom weight initialization function. Defaults to None.
+ bias_init (Callable, optional): Custom bias initialization function. Defaults to None.
+ premap_mode (str, optional): Feature mapping mode before network. Options: "gaussian", "positional", "rbf". Defaults to None.
+
+ Examples:
+ >>> import ppsci
+ >>> model = ppsci.arch.SIRENAutodecoder_film(
+ ... input_keys=["coords", "latents"],
+ ... output_keys=("output",),
+ ... in_coord_features=2,
+ ... in_latent_features=128,
+ ... out_features=3,
+ ... num_hidden_layers=10,
+ ... hidden_features=128,
+ ... )
+ >>> input_data = {
+ ... "coords": paddle.randn([1000, 2]),
+ ... "latents": paddle.randn([1000, 128])
+ ... }
+ >>> out_dict = model(input_data)
+ >>> print(out_dict["output"].shape)
+ [1000, 3]
+ """
+
+ def __init__(
+ self,
+ input_keys,
+ output_keys,
+ in_coord_features,
+ in_latent_features,
+ out_features,
+ num_hidden_layers,
+ hidden_features,
+ outermost_linear=False,
+ nonlinearity="sine",
+ weight_init=None,
+ bias_init=None,
+ premap_mode=None,
+ **kwargs,
+ ):
+ super().__init__()
+ self.input_keys = input_keys
+ self.output_keys = output_keys
+
+ self.premap_mode = premap_mode
+ if self.premap_mode is not None:
+ self.premap_layer = FeatureMapping(
+ in_coord_features, mode=premap_mode, **kwargs
+ )
+ in_coord_features = self.premap_layer.dim
+ self.first_layer_init = None
+ self.nl, nl_weight_init, first_layer_init = NLS_AND_INITS[nonlinearity]
+ if weight_init is not None:
+ self.weight_init = weight_init
+ else:
+ self.weight_init = nl_weight_init
+ self.net1 = paddle.nn.LayerList(
+ sublayers=[BatchLinear(in_coord_features, hidden_features)]
+ + [
+ BatchLinear(hidden_features, hidden_features)
+ for i in range(num_hidden_layers)
+ ]
+ + [BatchLinear(hidden_features, out_features)]
+ )
+ self.net2 = paddle.nn.LayerList(
+ sublayers=[
+ BatchLinear(in_latent_features, hidden_features, bias_attr=False)
+ for i in range(num_hidden_layers + 1)
+ ]
+ )
+ if self.weight_init is not None:
+ self.net1.apply(self.weight_init)
+ self.net2.apply(self.weight_init)
+ if first_layer_init is not None:
+ self.net1[0].apply(first_layer_init)
+ self.net2[0].apply(first_layer_init)
+ if bias_init is not None:
+ self.net2.apply(bias_init)
+
+ def forward(self, input_data):
+ coords = input_data[self.input_keys[0]]
+ latents = input_data[self.input_keys[1]]
+ if self.premap_mode is not None:
+ x = self.premap_layer(coords)
+ else:
+ x = coords
+
+ for i in range(len(self.net1) - 1):
+ x = self.net1[i](x) + self.net2[i](latents)
+ x = self.nl(x)
+ x = self.net1[-1](x)
+ return {self.output_keys[0]: x}
+
+ def disable_gradient(self):
+ for param in self.parameters():
+ param.stop_gradient = not False
+
+
+class LatentContainer(paddle.nn.Layer):
+ """
+ Learnable latent code container for autodecoding applications.
+
+ This module stores and retrieves per-sample latent codes, which can be used
+ for representing multiple instances (shapes, scenes) with a single decoder network.
+ Supports multi-GPU training and different dimensional arrangements.
+
+ Reference:
+ Park et al. "DeepSDF: Learning Continuous Signed Distance Functions for Shape Representation" (CVPR 2019)
+
+ Args:
+ input_keys (Tuple[str, ...], optional): Key to get batch indices from dict. Defaults to ("input",).
+ output_keys (Tuple[str, ...], optional): Key to save latent codes into dict. Defaults to ("output",).
+ N_samples (int, optional): Total number of samples/instances in dataset. Defaults to None.
+ N_features (int, optional): Dimension of latent codes. Defaults to None.
+ dims (int, optional): Number of spatial dimensions (for proper broadcasting). Defaults to None.
+ lumped (bool, optional): If True, adds single dimension; if False, adds dims dimensions. Defaults to False.
+
+ Examples:
+ >>> import ppsci
+ >>> import paddle
+ >>> model = ppsci.arch.LatentContainer(
+ ... N_samples=1600,
+ ... N_features=128,
+ ... dims=2,
+ ... lumped=True
+ ... )
+ >>> batch_indices = paddle.randint(0, 1600, [32], dtype='int64')
+ >>> input_dict = {"input": batch_indices}
+ >>> out_dict = model(input_dict)
+ >>> print(out_dict["output"].shape)
+ [32, 1, 128]
+ """
+
+ def __init__(
+ self,
+ input_keys=("input",),
+ output_keys=("output",),
+ N_samples=None,
+ N_features=None,
+ dims=None,
+ lumped=False,
+ ):
+ super().__init__()
+ self.input_keys = input_keys
+ self.output_keys = output_keys
+ self.dims = [1] * dims if not lumped else [1]
+ self.expand_dims = " ".join(["1" for _ in range(dims)]) if not lumped else "1"
+ self.expand_dims = f"N f -> N {self.expand_dims} f"
+ self.latents = self.create_parameter(
+ shape=(N_samples, N_features),
+ dtype="float32",
+ default_initializer=paddle.nn.initializer.Constant(0.0),
+ )
+
+ def forward(self, batch_ids):
+ x = batch_ids[self.input_keys[0]]
+ selected_latents = paddle.gather(self.latents, x)
+ if len(selected_latents.shape) > 1:
+ getShape = (
+ [tuple(selected_latents.shape)[0]]
+ + self.dims
+ + [tuple(selected_latents.shape)[1]]
+ )
+ else:
+ getShape = [-1] + self.dims
+ expanded_latents = selected_latents.reshape(getShape)
+ return {self.output_keys[0]: expanded_latents}
+
+###################### GaussianDiffusion Model #######################
+class ModelVarType(enum.Enum):
+
+ LEARNED = enum.auto()
+ FIXED_SMALL = enum.auto()
+ FIXED_LARGE = enum.auto()
+ LEARNED_RANGE = enum.auto()
+
+
+def _extract_into_tensor(arr, timesteps, broadcast_shape):
+ res = (
+ paddle.to_tensor(data=arr)[timesteps]
+ .astype(dtype="float32")
+ )
+ while len(tuple(res.shape)) < len(broadcast_shape):
+ res = res[..., None]
+ return res.expand(shape=broadcast_shape)
+
+
+def split(x, num_or_sections, axis=0):
+ if isinstance(num_or_sections, int):
+ return paddle.split(x, x.shape[axis]//num_or_sections, axis)
+ else:
+ return paddle.split(x, num_or_sections, axis)
+
+
+class ModelMeanType(enum.Enum):
+ PREVIOUS_X = enum.auto()
+ START_X = enum.auto()
+ EPSILON = enum.auto()
+
+
+def mean_flat(tensor):
+ return paddle.mean(tensor, axis=list(range(1, len(tensor.shape))))
+
+
+def normal_kl(mean1, logvar1, mean2, logvar2):
+ tensor = None
+ for obj in (mean1, logvar1, mean2, logvar2):
+ if isinstance(obj, paddle.Tensor):
+ tensor = obj
+ break
+ assert tensor is not None, "at least one argument must be a Tensor"
+
+ # Force variances to be Tensors. Broadcasting helps convert scalars to
+ # Tensors, but it does not work for th.exp().
+ logvar1, logvar2 = [
+ x if isinstance(x, paddle.Tensor) else paddle.to_tensor(x).to(tensor)
+ for x in (logvar1, logvar2)
+ ]
+
+ return 0.5 * (
+ -1.0
+ + logvar2
+ - logvar1
+ + paddle.exp(logvar1 - logvar2)
+ + ((mean1 - mean2) ** 2) * paddle.exp(-logvar2)
+ )
+
+
+class GaussianDiffusion:
+ """
+ Gaussian diffusion process for denoising diffusion probabilistic models (DDPM).
+
+ Implements the forward diffusion process q(x_t|x_0) and reverse denoising process p(x_{t-1}|x_t).
+ Supports various parameterizations (epsilon, x_0, x_{t-1}) and variance schedules.
+
+ Reference:
+ Ho et al. "Denoising Diffusion Probabilistic Models" (NeurIPS 2020)
+ Nichol & Dhariwal "Improved Denoising Diffusion Probabilistic Models" (ICML 2021)
+
+ Args:
+ betas (np.ndarray): Noise schedule β_t for t=0,...,T-1.
+ model_mean_type (ModelMeanType): Parameterization of model output.
+ model_var_type (ModelVarType): Variance parameterization (fixed or learned).
+ loss_type (LossType): Loss function type (MSE, KL, etc.).
+ rescale_timesteps (bool, optional): Rescale timesteps to [0, 1000]. Defaults to False.
+ """
+ def __init__(
+ self,
+ *,
+ betas,
+ model_mean_type,
+ model_var_type,
+ loss_type,
+ rescale_timesteps=False,
+ ):
+ self.model_mean_type = model_mean_type
+ self.model_var_type = model_var_type
+ self.loss_type = loss_type
+ self.rescale_timesteps = rescale_timesteps
+
+ # Use float64 for accuracy.
+ betas = np.array(betas, dtype=np.float64)
+ self.betas = betas
+ assert len(tuple(betas.shape)) == 1, "betas must be 1-D"
+ assert (betas > 0).astype("bool").all() and (betas <= 1).astype("bool").all()
+
+ self.num_timesteps = int(tuple(betas.shape)[0])
+
+ alphas = 1.0 - betas
+ self.alphas_cumprod = np.cumprod(alphas, axis=0)
+ self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
+ self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
+ assert tuple(self.alphas_cumprod_prev.shape) == (self.num_timesteps,)
+
+ # calculations for diffusion q(x_t | x_{t-1}) and others
+ self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
+ self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
+ self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
+ self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
+ self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
+ self.posterior_variance = (
+ betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
+ )
+ self.posterior_log_variance_clipped = np.log(
+ np.append(self.posterior_variance[1], self.posterior_variance[1:])
+ )
+ self.posterior_mean_coef1 = (
+ betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
+ )
+ self.posterior_mean_coef2 = (
+ (1.0 - self.alphas_cumprod_prev)
+ * np.sqrt(alphas)
+ / (1.0 - self.alphas_cumprod)
+ )
+
+ def q_mean_variance(self, x_start, t):
+ mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+ variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
+ log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
+ return mean, variance, log_variance
+
+ def q_sample(self, x_start, t, noise=None):
+ if noise is None:
+ noise = paddle.randn(x_start.shape)
+
+ sqrt_alpha_cumprod_t = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape)
+ sqrt_one_minus_alpha_cumprod_t = _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
+
+ return sqrt_alpha_cumprod_t * x_start + sqrt_one_minus_alpha_cumprod_t * noise
+
+ def _predict_xstart_from_xprev(self, x_t, t, xprev):
+ assert tuple(x_t.shape) == tuple(xprev.shape)
+ return (
+ _extract_into_tensor(1.0 / self.posterior_mean_coef1, t, tuple(x_t.shape))
+ * xprev
+ - _extract_into_tensor(
+ self.posterior_mean_coef2 / self.posterior_mean_coef1,
+ t,
+ tuple(x_t.shape),
+ )
+ * x_t
+ )
+
+ def _predict_xstart_from_eps(self, x_t, t, eps):
+ assert tuple(x_t.shape) == tuple(eps.shape)
+ return (
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, tuple(x_t.shape))
+ * x_t
+ - _extract_into_tensor(
+ self.sqrt_recipm1_alphas_cumprod, t, tuple(x_t.shape)
+ )
+ * eps
+ )
+
+ def p_mean_variance(
+ self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None
+ ):
+ if model_kwargs is None:
+ model_kwargs = {}
+ B, C = tuple(x.shape)[:2]
+ assert tuple(t.shape) == (B,)
+ model_output = model(x, self._scale_timesteps(t), **model_kwargs)
+ if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
+ assert tuple(model_output.shape) == (B, C * 2, *tuple(x.shape)[2:])
+ model_output, model_var_values = split(
+ x=model_output, num_or_sections=C, axis=1
+ )
+ if self.model_var_type == ModelVarType.LEARNED:
+ model_log_variance = model_var_values
+ model_variance = paddle.exp(x=model_log_variance)
+ else:
+ min_log = _extract_into_tensor(
+ self.posterior_log_variance_clipped, t, tuple(x.shape)
+ )
+ max_log = _extract_into_tensor(np.log(self.betas), t, tuple(x.shape))
+ frac = (model_var_values + 1) / 2
+ model_log_variance = frac * max_log + (1 - frac) * min_log
+ model_variance = paddle.exp(x=model_log_variance)
+ else:
+ model_variance, model_log_variance = {
+ ModelVarType.FIXED_LARGE: (
+ np.append(self.posterior_variance[1], self.betas[1:]),
+ np.log(np.append(self.posterior_variance[1], self.betas[1:])),
+ ),
+ ModelVarType.FIXED_SMALL: (
+ self.posterior_variance,
+ self.posterior_log_variance_clipped,
+ ),
+ }[self.model_var_type]
+ model_variance = _extract_into_tensor(model_variance, t, tuple(x.shape))
+ model_log_variance = _extract_into_tensor(
+ model_log_variance, t, tuple(x.shape)
+ )
+
+ def process_xstart(x):
+ if denoised_fn is not None:
+ x = denoised_fn(x)
+ if clip_denoised:
+ return x.clip(min=-1, max=1)
+ return x
+
+ if self.model_mean_type == ModelMeanType.PREVIOUS_X:
+ pred_xstart = process_xstart(
+ self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output)
+ )
+ model_mean = model_output
+ elif self.model_mean_type in [ModelMeanType.START_X, ModelMeanType.EPSILON]:
+ if self.model_mean_type == ModelMeanType.START_X:
+ pred_xstart = process_xstart(model_output)
+ else:
+ pred_xstart = process_xstart(
+ self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
+ )
+ model_mean, _, _ = self.q_posterior_mean_variance(
+ x_start=pred_xstart, x_t=x, t=t
+ )
+ else:
+ raise NotImplementedError(self.model_mean_type)
+ assert (
+ tuple(model_mean.shape)
+ == tuple(model_log_variance.shape)
+ == tuple(pred_xstart.shape)
+ == tuple(x.shape)
+ )
+ return {
+ "mean": model_mean,
+ "variance": model_variance,
+ "log_variance": model_log_variance,
+ "pred_xstart": pred_xstart,
+ }
+
+ def q_posterior_mean_variance(self, x_start, x_t, t):
+ assert tuple(x_start.shape) == tuple(x_t.shape)
+ posterior_mean = (
+ _extract_into_tensor(self.posterior_mean_coef1, t, tuple(x_t.shape))
+ * x_start
+ + _extract_into_tensor(self.posterior_mean_coef2, t, tuple(x_t.shape)) * x_t
+ )
+ posterior_variance = _extract_into_tensor(
+ self.posterior_variance, t, tuple(x_t.shape)
+ )
+ posterior_log_variance_clipped = _extract_into_tensor(
+ self.posterior_log_variance_clipped, t, tuple(x_t.shape)
+ )
+ assert (
+ tuple(posterior_mean.shape)[0]
+ == tuple(posterior_variance.shape)[0]
+ == tuple(posterior_log_variance_clipped.shape)[0]
+ == tuple(x_start.shape)[0]
+ )
+ return (posterior_mean, posterior_variance, posterior_log_variance_clipped)
+
+ def _scale_timesteps(self, t):
+ if self.rescale_timesteps:
+ return t.astype(dtype="float32") * (1000.0 / self.num_timesteps)
+ return t
+
+ def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
+ if model_kwargs is None:
+ model_kwargs = {}
+ gradient = cond_fn(x, self._scale_timesteps(t), **model_kwargs)
+ new_mean = p_mean_var["mean"].astype(dtype="float32") + p_mean_var[
+ "variance"
+ ] * gradient.astype(dtype="float32")
+ return new_mean
+
+ def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
+ return (
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
+ - pred_xstart
+ ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
+
+ def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
+ if model_kwargs is None:
+ model_kwargs = {}
+ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
+
+ eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
+ eps = eps - (1 - alpha_bar).sqrt() * cond_fn(
+ x, self._scale_timesteps(t), **model_kwargs
+ )
+
+ out = p_mean_var.copy()
+ out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
+ out["mean"], _, _ = self.q_posterior_mean_variance(
+ x_start=out["pred_xstart"], x_t=x, t=t
+ )
+ return out
+
+ def p_sample(
+ self,
+ model,
+ x,
+ t,
+ clip_denoised=True,
+ denoised_fn=None,
+ cond_fn=None,
+ model_kwargs=None,
+ ):
+ """
+ Sample x_{t-1} from the model at the given timestep.
+ """
+ out = self.p_mean_variance(
+ model,
+ x,
+ t,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ model_kwargs=model_kwargs,
+ )
+ noise = paddle.randn(shape=x.shape, dtype=x.dtype)
+ nonzero_mask = (
+ (t != 0).astype(dtype="float32").reshape([-1, *([1] * (len(tuple(x.shape)) - 1))])
+ )
+ if cond_fn is not None:
+ out["mean"] = self.condition_mean(
+ cond_fn, out, x, t, model_kwargs=model_kwargs
+ )
+ sample = (
+ out["mean"] + nonzero_mask * paddle.exp(x=0.5 * out["log_variance"]) * noise
+ )
+ return {"sample": sample, "pred_xstart": out["pred_xstart"]}
+
+ def p_sample_loop(
+ self,
+ model,
+ shape,
+ noise=None,
+ clip_denoised=True,
+ denoised_fn=None,
+ cond_fn=None,
+ model_kwargs=None,
+ device=None,
+ progress=False,
+ ):
+ """
+ Generate samples from the model.
+ """
+ final = None
+ for sample in self.p_sample_loop_progressive(
+ model,
+ shape,
+ noise=noise,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ cond_fn=cond_fn,
+ model_kwargs=model_kwargs,
+ device=device,
+ progress=progress,
+ ):
+ final = sample
+ return final["sample"]
+
+ def p_sample_loop_progressive(
+ self,
+ model,
+ shape,
+ noise=None,
+ clip_denoised=True,
+ denoised_fn=None,
+ cond_fn=None,
+ model_kwargs=None,
+ device=None,
+ progress=False,
+ ):
+ """
+ Generate samples from the model and yield intermediate samples from
+ each timestep of diffusion.
+
+ Arguments are the same as p_sample_loop().
+ Returns a generator over dicts, where each dict is the return value of
+ p_sample().
+ """
+ assert isinstance(shape, (tuple, list))
+ if noise is not None:
+ img = noise
+ else:
+ img = paddle.randn(shape=shape)
+ indices = list(range(self.num_timesteps))[::-1]
+ if progress:
+ from tqdm.auto import tqdm
+
+ indices = tqdm(indices)
+ for i in indices:
+ t = paddle.to_tensor(data=[i] * shape[0])
+ with paddle.no_grad():
+ out = self.p_sample(
+ model,
+ img,
+ t,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ cond_fn=cond_fn,
+ model_kwargs=model_kwargs,
+ )
+ yield out
+ img = out["sample"]
+
+ def training_losses(self, model, x_start, t, model_kwargs=None, noise=None, valid=False):
+ if model_kwargs is None:
+ model_kwargs = {}
+ if noise is None:
+ noise = paddle.randn(x_start.shape)
+
+ x_t = self.q_sample(x_start=x_start, t=t, noise=noise)
+ # terms = {}
+ # model_output = model(x_t, t)
+
+ # # Handle different model outputs
+ # if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
+ # assert model_output.shape[1] == 2 * x_start.shape[1], "Output channels must be 2x input channels"
+ # model_output, model_var_values = split(model_output, 2, axis=1)
+
+ # Calculate the MSE loss for epsilon prediction
+ terms = {}
+ if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
+ terms["loss"] = self._vb_terms_bpd(
+ model=model,
+ x_start=x_start,
+ x_t=x_t,
+ t=t,
+ clip_denoised=True,
+ model_kwargs=model_kwargs,
+ )["output"]
+ if self.loss_type == LossType.RESCALED_KL:
+ terms["loss"] *= self.num_timesteps
+ elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
+ model_output = model(x_t, self._scale_timesteps(t), **model_kwargs)
+
+ if self.model_var_type in [
+ ModelVarType.LEARNED,
+ ModelVarType.LEARNED_RANGE,
+ ]:
+ B, C = x_t.shape[:2]
+ assert model_output.shape == (B, C * 2, *x_t.shape[2:])
+ model_output, model_var_values = split(model_output, C, axis=1)
+ # Learn the variance using the variational bound, but don't let
+ # it affect our mean prediction.
+ frozen_out = paddle.cat([model_output.detach(), model_var_values], dim=1)
+ terms["vb"] = self._vb_terms_bpd(
+ model=lambda *args, r=frozen_out: r,
+ x_start=x_start,
+ x_t=x_t,
+ t=t,
+ clip_denoised=True,
+ )["output"]
+ if self.loss_type == LossType.RESCALED_MSE:
+ # Divide by 1000 for equivalence with initial implementation.
+ # Without a factor of 1/1000, the VB term hurts the MSE term.
+ terms["vb"] *= self.num_timesteps / 1000.0
+
+ target = {
+ ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(
+ x_start=x_start, x_t=x_t, t=t
+ )[0],
+ ModelMeanType.START_X: x_start,
+ ModelMeanType.EPSILON: noise,
+ }[self.model_mean_type]
+ assert model_output.shape == target.shape == x_start.shape
+
+ if valid == False:
+ terms["mse"] = mean_flat((target - model_output) ** 2)
+ if "vb" in terms:
+ terms["loss"] = terms["mse"] + terms["vb"]
+ else:
+ terms["loss"] = terms["mse"]
+ else:
+ terms["valid_mse"] = mean_flat((target - model_output) ** 2)
+ if "vb" in terms:
+ terms["loss"] = terms["valid_mse"] + terms["vb"]
+ else:
+ terms["loss"] = terms["valid_mse"]
+ else:
+ raise NotImplementedError(self.loss_type)
+
+ return terms
+
+ def _vb_terms_bpd(
+ self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None
+ ):
+ true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(
+ x_start=x_start, x_t=x_t, t=t
+ )
+ out = self.p_mean_variance(
+ model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs
+ )
+ kl = normal_kl(
+ true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]
+ )
+ kl = mean_flat(kl) / np.log(2.0)
+
+ decoder_nll = -discretized_gaussian_log_likelihood(
+ x_start, means=out["mean"], log_scales=0.5 * out["log_variance"]
+ )
+ assert decoder_nll.shape == x_start.shape
+ decoder_nll = mean_flat(decoder_nll) / np.log(2.0)
+
+ # At the first timestep return the decoder NLL,
+ # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
+ output = paddle.where((t == 0), decoder_nll, kl)
+ return {"output": output, "pred_xstart": out["pred_xstart"]}
+
+
+def discretized_gaussian_log_likelihood(x, *, means, log_scales):
+ assert x.shape == means.shape == log_scales.shape
+ centered_x = x - means
+ inv_stdv = paddle.exp(-log_scales)
+ plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
+ cdf_plus = approx_standard_normal_cdf(plus_in)
+ min_in = inv_stdv * (centered_x - 1.0 / 255.0)
+ cdf_min = approx_standard_normal_cdf(min_in)
+ log_cdf_plus = paddle.log(cdf_plus.clamp(min=1e-12))
+ log_one_minus_cdf_min = paddle.log((1.0 - cdf_min).clamp(min=1e-12))
+ cdf_delta = cdf_plus - cdf_min
+ log_probs = paddle.where(
+ x < -0.999,
+ log_cdf_plus,
+ paddle.where(x > 0.999, log_one_minus_cdf_min, paddle.log(cdf_delta.clamp(min=1e-12))),
+ )
+ assert log_probs.shape == x.shape
+ return log_probs
+
+
+def approx_standard_normal_cdf(x):
+ """
+ A fast approximation of the cumulative distribution function of the
+ standard normal.
+ """
+ return 0.5 * (1.0 + paddle.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * paddle.pow(x, 3))))
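+# Sanity check (illustrative): the tanh-based approximation is exact at x = 0,
+# where the true standard normal CDF equals 0.5:
+#   approx_standard_normal_cdf(paddle.to_tensor([0.0]))  # -> [0.5]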
+
+
+class LossType(enum.Enum):
+ MSE = enum.auto() # use raw MSE loss (and KL when learning variances)
+ RESCALED_MSE = (
+ enum.auto()
+ ) # use raw MSE loss (with RESCALED_KL when learning variances)
+ KL = enum.auto() # use the variational lower-bound
+ RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB
+
+ def is_vb(self):
+ return self == LossType.KL or self == LossType.RESCALED_KL
+
+
+class SpacedDiffusion(GaussianDiffusion):
+ """
+ Accelerated diffusion process that skips timesteps for faster sampling.
+
+ Implements DDIM-style sampling by using a subset of timesteps from the original
+ diffusion process, enabling faster inference without retraining the model.
+
+ Reference:
+ Song et al. "Denoising Diffusion Implicit Models" (ICLR 2021)
+
+ Args:
+        use_timesteps (Sequence[int]): Collection of timesteps to retain from original process
+            (e.g., [0, 10, 20, ..., 990] for 100-step sampling from a 1000-step schedule).
+ **kwargs: Additional arguments for base GaussianDiffusion (betas, model_mean_type, etc.).
+ """
+
+ def __init__(self, use_timesteps, **kwargs):
+ self.use_timesteps = set(use_timesteps)
+ self.timestep_map = []
+ self.original_num_steps = len(kwargs["betas"])
+ base_diffusion = GaussianDiffusion(**kwargs)
+ last_alpha_cumprod = 1.0
+ new_betas = []
+        # Recompute betas so that the retained timesteps reproduce the same
+        # cumulative alpha products as the original schedule.
+        for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
+ if i in self.use_timesteps:
+ new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
+ last_alpha_cumprod = alpha_cumprod
+ self.timestep_map.append(i)
+ kwargs["betas"] = np.array(new_betas)
+ super().__init__(**kwargs)
+
+ def p_mean_variance(self, model, *args, **kwargs):
+ return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
+
+ def training_losses(self, model, *args, **kwargs):
+ return super().training_losses(self._wrap_model(model), *args, **kwargs)
+
+ def condition_mean(self, cond_fn, *args, **kwargs):
+ return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
+
+ def condition_score(self, cond_fn, *args, **kwargs):
+ return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
+
+ def _wrap_model(self, model):
+ if isinstance(model, _WrappedModel):
+ return model
+ return _WrappedModel(
+ model, self.timestep_map, self.rescale_timesteps, self.original_num_steps
+ )
+
+ def _scale_timesteps(self, t):
+ return t
+
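+# Usage sketch (hypothetical values; the constructor kwargs follow the
+# GaussianDiffusion definitions in this file): retain every 10th step of a
+# 1000-step schedule to sample in 100 steps without retraining.
+#   diffusion = SpacedDiffusion(
+#       use_timesteps=range(0, 1000, 10),
+#       betas=np.linspace(1e-4, 0.02, 1000),
+#       model_mean_type=ModelMeanType.EPSILON,
+#       model_var_type=ModelVarType.LEARNED_RANGE,
+#       loss_type=LossType.MSE,
+#   )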
+
+class _WrappedModel:
+ def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps):
+ self.model = model
+ self.timestep_map = timestep_map
+ self.rescale_timesteps = rescale_timesteps
+ self.original_num_steps = original_num_steps
+
+    def __call__(self, x, ts, **kwargs):
+        # Map spaced-schedule indices back to indices in the original schedule
+        # before calling the wrapped model.
+        map_tensor = paddle.to_tensor(data=self.timestep_map, dtype=ts.dtype)
+ new_ts = map_tensor[ts]
+ if self.rescale_timesteps:
+ new_ts = new_ts.astype(dtype="float32") * (1000.0 / self.original_num_steps)
+ return self.model(x, new_ts, **kwargs)
+
+
+###################### UNET Model #######################
+def conv_nd(dims, *args, **kwargs):
+    """Create a 1D, 2D, or 3D convolution module."""
+    if dims == 1:
+        return paddle.nn.Conv1D(*args, **kwargs)
+    elif dims == 2:
+        return paddle.nn.Conv2D(*args, **kwargs)
+    elif dims == 3:
+        return paddle.nn.Conv3D(*args, **kwargs)
+    raise ValueError(f"unsupported dimensions: {dims}")
+
+
+def linear(*args, **kwargs):
+    """Create a linear module."""
+    return paddle.nn.Linear(*args, **kwargs)
+
+
+class TimestepBlock(paddle.nn.Layer):
+ @abstractmethod
+ def forward(self, x, emb):
+ """
+ Apply the module to `x` given `emb` timestep embeddings.
+ """
+ pass
+
+
+class ResBlock(TimestepBlock):
+ """
+ Residual block with timestep embedding for diffusion models.
+
+ Implements a residual connection with two convolutional layers, timestep conditioning,
+ and optional up/downsampling. Supports FiLM-style adaptive normalization.
+
+ Args:
+ channels (int): Number of input channels.
+ emb_channels (int): Number of timestep embedding channels.
+ dropout (float): Dropout probability.
+ out_channels (int, optional): Number of output channels. Defaults to channels.
+ use_conv (bool, optional): Use conv for skip connection if channels differ. Defaults to False.
+ use_scale_shift_norm (bool, optional): Use FiLM-style conditioning. Defaults to False.
+ dims (int, optional): Spatial dimensions (1D/2D/3D). Defaults to 2.
+ use_checkpoint (bool, optional): Use gradient checkpointing. Defaults to False.
+ up (bool, optional): Apply upsampling. Defaults to False.
+ down (bool, optional): Apply downsampling. Defaults to False.
+ """
+ def __init__(
+ self,
+ channels,
+ emb_channels,
+ dropout,
+ out_channels=None,
+ use_conv=False,
+ use_scale_shift_norm=False,
+ dims=2,
+ use_checkpoint=False,
+ up=False,
+ down=False,
+ ):
+ super().__init__()
+ self.channels = channels
+ self.emb_channels = emb_channels
+ self.dropout = dropout
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.use_checkpoint = use_checkpoint
+ self.use_scale_shift_norm = use_scale_shift_norm
+ self.in_layers = paddle.nn.Sequential(
+ normalization(channels),
+ paddle.nn.Silu(),
+ conv_nd(dims, channels, self.out_channels, 3, padding=1),
+ )
+ self.updown = up or down
+ if up:
+ self.h_upd = Upsample(channels, False, dims)
+ self.x_upd = Upsample(channels, False, dims)
+ elif down:
+ self.h_upd = Downsample(channels, False, dims)
+ self.x_upd = Downsample(channels, False, dims)
+ else:
+ self.h_upd = self.x_upd = paddle.nn.Identity()
+ self.emb_layers = paddle.nn.Sequential(
+ paddle.nn.Silu(),
+ linear(
+ emb_channels,
+ 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
+ ),
+ )
+ self.out_layers = paddle.nn.Sequential(
+ normalization(self.out_channels),
+ paddle.nn.Silu(),
+ paddle.nn.Dropout(p=dropout),
+ zero_module(
+ conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
+ ),
+ )
+ if self.out_channels == channels:
+ self.skip_connection = paddle.nn.Identity()
+ elif use_conv:
+ self.skip_connection = conv_nd(
+ dims, channels, self.out_channels, 3, padding=1
+ )
+ else:
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
+
+ def forward(self, x, emb):
+ return checkpoint(
+ self._forward, (x, emb), self.parameters(), self.use_checkpoint
+ )
+
+ def _forward(self, x, emb):
+ if self.updown:
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
+ h = in_rest(x)
+ h = self.h_upd(h)
+ x = self.x_upd(x)
+ h = in_conv(h)
+ else:
+ h = self.in_layers(x)
+        emb_out = self.emb_layers(emb).astype(h.dtype)
+        while len(tuple(emb_out.shape)) < len(tuple(h.shape)):
+            emb_out = emb_out[..., None]
+        if self.use_scale_shift_norm:
+            # FiLM-style conditioning: the timestep embedding predicts a
+            # per-channel scale and shift applied after normalization.
+            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
+            scale, shift = paddle.chunk(x=emb_out, chunks=2, axis=1)
+            h = out_norm(h) * (1 + scale) + shift
+            h = out_rest(h)
+        else:
+            h = h + emb_out
+            h = self.out_layers(h)
+        return self.skip_connection(x) + h
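+# Usage sketch (hypothetical shapes): a 2D ResBlock conditioned on a timestep
+# embedding of width emb_channels.
+#   block = ResBlock(channels=64, emb_channels=256, dropout=0.0)
+#   h = block(paddle.randn([4, 64, 32, 32]), paddle.randn([4, 256]))
+#   # h.shape -> [4, 64, 32, 32]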
+
+
+class TimestepEmbedSequential(paddle.nn.Sequential, TimestepBlock):
+    """
+    A Sequential container that passes timestep embeddings to children that
+    are TimestepBlocks and calls the rest on `x` alone.
+    """
+
+    def forward(self, x, emb):
+        for layer in self:
+            if isinstance(layer, TimestepBlock):
+                x = layer(x, emb)
+            else:
+                x = layer(x)
+        return x
+
+
+NUM_CLASSES = 1000
+
+
+def avg_pool_nd(dims, *args, **kwargs):
+ """
+ Create a 1D, 2D, or 3D average pooling module.
+ """
+ if dims == 1:
+ return paddle.nn.AvgPool1D(*args, **kwargs, exclusive=False)
+ elif dims == 2:
+ return paddle.nn.AvgPool2D(*args, **kwargs, exclusive=False)
+ elif dims == 3:
+ return paddle.nn.AvgPool3D(*args, **kwargs, exclusive=False)
+ raise ValueError(f"unsupported dimensions: {dims}")
+
+
+class Downsample(paddle.nn.Layer):
+ """
+ Spatial downsampling layer (2x reduction).
+
+ Can use either strided convolution or average pooling for downsampling.
+
+ Args:
+ channels (int): Number of input channels.
+ use_conv (bool): Use strided conv (True) or avg pooling (False).
+ dims (int, optional): Spatial dimensions. Defaults to 2.
+ out_channels (int, optional): Number of output channels. Defaults to channels.
+ """
+ def __init__(self, channels, use_conv, dims=2, out_channels=None):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.dims = dims
+ stride = 2 if dims != 3 else (1, 2, 2)
+ if use_conv:
+ self.op = conv_nd(
+ dims, self.channels, self.out_channels, 3, stride=stride, padding=1
+ )
+ else:
+ assert self.channels == self.out_channels
+ self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
+
+ def forward(self, x):
+ """Apply downsampling."""
+ assert tuple(x.shape)[1] == self.channels
+ return self.op(x)
+
+
+class Upsample(paddle.nn.Layer):
+ """
+ Spatial upsampling layer (2x expansion).
+
+ Uses nearest-neighbor interpolation followed by optional convolution.
+
+ Args:
+ channels (int): Number of input channels.
+ use_conv (bool): Apply convolution after upsampling.
+ dims (int, optional): Spatial dimensions. Defaults to 2.
+ out_channels (int, optional): Number of output channels. Defaults to channels.
+ """
+ def __init__(self, channels, use_conv, dims=2, out_channels=None):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.dims = dims
+ if use_conv:
+ self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)
+
+ def forward(self, x):
+ """Apply upsampling."""
+ assert tuple(x.shape)[1] == self.channels
+ if self.dims == 3:
+ x = paddle.nn.functional.interpolate(
+ x=x,
+ size=(tuple(x.shape)[2], tuple(x.shape)[3] * 2, tuple(x.shape)[4] * 2),
+ mode="nearest",
+ )
+ else:
+ x = paddle.nn.functional.interpolate(x=x, scale_factor=2, mode="nearest")
+ if self.use_conv:
+ x = self.conv(x)
+ return x
+
+
+def count_flops_attn(model, _x, y):
+    """
+    FLOP counter hook for attention layers (for profiling tools such as
+    `thop`): counts the two batched matmuls over the spatial positions.
+    """
+    b, c, *spatial = tuple(y[0].shape)
+    num_spatial = int(np.prod(spatial))
+    matmul_ops = 2 * b * (num_spatial**2) * c
+    model.total_ops += paddle.to_tensor(data=[matmul_ops], dtype="float64")
+
+
+class QKVAttentionLegacy(paddle.nn.Layer):
+    """
+    QKV attention that splits the heads before splitting q/k/v, matching the
+    ordering of the original DDPM implementation.
+    """
+
+    def __init__(self, n_heads):
+        super().__init__()
+        self.n_heads = n_heads
+
+    def forward(self, qkv):
+        bs, width, length = tuple(qkv.shape)
+        assert width % (3 * self.n_heads) == 0
+        ch = width // (3 * self.n_heads)
+        # Note: an int split_size means the chunk size in torch but the number
+        # of chunks in paddle; the `split` helper here follows the torch meaning.
+        (q, k, v) = split(qkv.reshape((bs * self.n_heads, ch * 3, length)), ch, 1)
+        # Scaling q and k by ch**-0.25 each equals the usual 1/sqrt(ch) on the
+        # product, but is more stable in low precision.
+        scale = 1 / math.sqrt(math.sqrt(ch))
+        weight = paddle.einsum("bct,bcs->bts", q * scale, k * scale)
+ weight = paddle.nn.functional.softmax(
+ x=weight.astype(dtype="float32"), axis=-1
+ ).astype(weight.dtype)
+ a = paddle.einsum("bts,bcs->bct", weight, v)
+ return a.reshape((bs, -1, length))
+
+ @staticmethod
+ def count_flops(model, _x, y):
+ return count_flops_attn(model, _x, y)
+
+
+class QKVAttention(paddle.nn.Layer):
+    """
+    QKV attention that splits q/k/v before splitting the heads (the "new
+    attention order").
+    """
+
+    def __init__(self, n_heads):
+        super().__init__()
+        self.n_heads = n_heads
+
+    def forward(self, qkv):
+        bs, width, length = tuple(qkv.shape)
+        assert width % (3 * self.n_heads) == 0
+        ch = width // (3 * self.n_heads)
+        (q, k, v) = qkv.chunk(chunks=3, axis=1)
+        scale = 1 / math.sqrt(math.sqrt(ch))
+        weight = paddle.einsum(
+            "bct,bcs->bts",
+            (q * scale).reshape((bs * self.n_heads, ch, length)),
+            (k * scale).reshape((bs * self.n_heads, ch, length)),
+        )
+ weight = paddle.nn.functional.softmax(
+ x=weight.astype(dtype="float32"), axis=-1
+ ).astype(weight.dtype)
+ a = paddle.einsum(
+ "bts,bcs->bct", weight, v.reshape((bs * self.n_heads, ch, length))
+ )
+ return a.reshape((bs, -1, length))
+
+ @staticmethod
+ def count_flops(model, _x, y):
+ return count_flops_attn(model, _x, y)
+
+
+class GroupNorm32(paddle.nn.GroupNorm):
+    """GroupNorm that computes in float32 and casts back to the input dtype."""
+
+    def forward(self, x):
+        return super().forward(x.astype(dtype="float32")).astype(x.dtype)
+
+
+def normalization(channels):
+    """Create a standard normalization layer: GroupNorm with 32 groups."""
+    return GroupNorm32(32, channels)
+
+
+def zero_module(module):
+    """Zero out the parameters of a module and return it."""
+    for p in module.parameters():
+        p.detach().zero_()
+    return module
+
+
+def checkpoint(func, inputs, params, flag):
+    """
+    Evaluate `func` without caching intermediate activations when `flag` is
+    True, trading extra compute in the backward pass for lower memory use.
+    """
+    if flag:
+        args = tuple(inputs) + tuple(params)
+        return CheckpointFunction.apply(func, len(inputs), *args)
+    else:
+        return func(*inputs)
+
+
+class CheckpointFunction(paddle.autograd.PyLayer):
+    @staticmethod
+    def forward(ctx, run_function, length, *args):
+        ctx.run_function = run_function
+        ctx.input_tensors = list(args[:length])
+        ctx.input_params = list(args[length:])
+        # Run the forward pass without building a graph; activations are
+        # recomputed in backward() instead of being stored.
+        with paddle.no_grad():
+            output_tensors = ctx.run_function(*ctx.input_tensors)
+        return output_tensors
+
+    @staticmethod
+    def backward(ctx, *output_grads):
+        ctx.input_tensors = [stop_gradient(x, stop=False) for x in ctx.input_tensors]
+        with paddle.enable_grad():
+            # Shallow copies detach the stored inputs' view history so the
+            # recomputed forward builds a fresh graph with respect to them.
+            shallow_copies = [x.view_as(other=x) for x in ctx.input_tensors]
+            output_tensors = ctx.run_function(*shallow_copies)
+        input_grads = paddle.grad(
+            outputs=output_tensors,
+            inputs=ctx.input_tensors + ctx.input_params,
+            grad_outputs=output_grads,
+            allow_unused=True,
+        )
+        del ctx.input_tensors
+        del ctx.input_params
+        del output_tensors
+        # PyLayer requires backward to return a tuple of gradients.
+        return tuple(input_grads)
+
+
+def stop_gradient(input, stop):
+    """Set a tensor's `stop_gradient` flag in place and return the tensor."""
+    input.stop_gradient = stop
+    return input
+
+
+class AttentionBlock(paddle.nn.Layer):
+ """
+ Self-attention block for spatial feature maps.
+
+ Applies multi-head self-attention over spatial locations in feature maps,
+ allowing the model to capture long-range dependencies.
+
+ Args:
+ channels (int): Number of input/output channels.
+ num_heads (int, optional): Number of attention heads. Defaults to 1.
+ num_head_channels (int, optional): Channels per head (overrides num_heads). Defaults to -1.
+ use_checkpoint (bool, optional): Use gradient checkpointing. Defaults to False.
+ use_new_attention_order (bool, optional): Use optimized attention implementation. Defaults to False.
+ """
+ def __init__(
+ self,
+ channels,
+ num_heads=1,
+ num_head_channels=-1,
+ use_checkpoint=False,
+ use_new_attention_order=False,
+ ):
+ super().__init__()
+ self.channels = channels
+ if num_head_channels == -1:
+ self.num_heads = num_heads
+ else:
+ assert (
+ channels % num_head_channels == 0
+ ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
+ self.num_heads = channels // num_head_channels
+ self.use_checkpoint = use_checkpoint
+ self.norm = normalization(channels)
+ self.qkv = conv_nd(1, channels, channels * 3, 1)
+ if use_new_attention_order:
+ self.attention = QKVAttention(self.num_heads)
+ else:
+ self.attention = QKVAttentionLegacy(self.num_heads)
+ self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
+
+ def forward(self, x):
+ return checkpoint(self._forward, (x,), self.parameters(), True)
+
+ def _forward(self, x):
+ b, c, *spatial = tuple(x.shape)
+ x = x.reshape((b, c, -1))
+ qkv = self.qkv(self.norm(x))
+ h = self.attention(qkv)
+ h = self.proj_out(h)
+ return (x + h).reshape((b, c, *spatial))
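+# Usage sketch (hypothetical shapes): self-attention over 32x32 feature maps.
+#   attn = AttentionBlock(channels=64, num_heads=4)
+#   y = attn(paddle.randn([4, 64, 32, 32]))  # y.shape -> [4, 64, 32, 32]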
+
+
+def convert_module_to_f16(l):
+ if isinstance(l, (paddle.nn.Conv1D, paddle.nn.Conv2D, paddle.nn.Conv3D)):
+ l.weight.data = l.weight.data.astype(dtype="float16")
+ if l.bias is not None:
+ l.bias.data = l.bias.data.astype(dtype="float16")
+
+
+def convert_module_to_f32(l):
+ if isinstance(l, (paddle.nn.Conv1D, paddle.nn.Conv2D, paddle.nn.Conv3D)):
+ l.weight.data = l.weight.data.astype(dtype="float32")
+ if l.bias is not None:
+ l.bias.data = l.bias.data.astype(dtype="float32")
+
+
+def timestep_embedding(timesteps, dim, max_period=10000):
+ """
+ Create sinusoidal timestep embeddings for diffusion models.
+
+ Similar to positional encodings in transformers, but for continuous timesteps.
+ Uses sinusoids of exponentially increasing frequencies.
+
+ Args:
+ timesteps (paddle.Tensor): Timestep values of shape (batch_size,).
+ dim (int): Embedding dimension.
+ max_period (int, optional): Maximum period for sinusoids. Defaults to 10000.
+
+ Returns:
+ paddle.Tensor: Timestep embeddings of shape (batch_size, dim).
+ """
+ half = dim // 2
+    freqs = paddle.exp(
+        x=-math.log(max_period)
+        * paddle.arange(start=0, end=half, dtype="float32")
+        / half
+    )
+ args = timesteps[:, None].astype(dtype="float32") * freqs[None]
+ embedding = paddle.concat(x=[paddle.cos(x=args), paddle.sin(x=args)], axis=-1)
+ if dim % 2:
+ embedding = paddle.concat(
+ x=[embedding, paddle.zeros_like(x=embedding[:, :1])], axis=-1
+ )
+ return embedding
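+# Shape sketch (illustrative): embed a batch of 2 timesteps into 128 dims.
+#   emb = timestep_embedding(paddle.to_tensor([0, 500]), dim=128)
+#   emb.shape  # -> [2, 128]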
+
+
+class UNetModel(paddle.nn.Layer):
+ """
+ Full UNet model with attention and timestep embedding for diffusion models.
+
+ Implements a U-Net architecture with residual blocks, self-attention at multiple resolutions,
+ and timestep conditioning via adaptive normalization (FiLM). Designed for denoising diffusion
+ probabilistic models (DDPM) and can be conditioned on class labels.
+
+ Reference:
+ Ronneberger et al. "U-Net: Convolutional Networks for Biomedical Image Segmentation" (MICCAI 2015)
+ Dhariwal & Nichol "Diffusion Models Beat GANs on Image Synthesis" (NeurIPS 2021)
+
+ Args:
+ image_size (int): Input image size (maintained for interface compatibility).
+ in_channels (int): Number of channels in input tensor.
+ model_channels (int): Base channel count for model (multiplied by channel_mult).
+ out_channels (int): Number of channels in output tensor.
+ num_res_blocks (int): Number of residual blocks per downsampling level.
+ attention_resolutions (list/tuple): Downsample factors where to apply attention (e.g., [4, 8, 16]).
+ dropout (float, optional): Dropout probability in residual blocks. Defaults to 0.0.
+ channel_mult (tuple, optional): Channel multipliers per level (e.g., (1, 2, 4, 8)). Defaults to (1, 2, 4, 8).
+ conv_resample (bool, optional): Use learned convolutional up/downsampling. Defaults to True.
+ dims (int, optional): Data dimensionality (1=1D, 2=2D, 3=3D). Defaults to 2.
+ num_classes (int, optional): Number of classes for class-conditional generation. Defaults to None.
+ use_checkpoint (bool, optional): Enable gradient checkpointing to save memory. Defaults to False.
+ use_fp16 (bool, optional): Use float16 precision for forward pass. Defaults to False.
+ num_heads (int, optional): Number of attention heads in each attention block. Defaults to 1.
+ num_head_channels (int, optional): Fixed channels per head (overrides num_heads if set). Defaults to -1.
+ num_heads_upsample (int, optional): Attention heads for upsampling blocks. Defaults to -1 (use num_heads).
+ use_scale_shift_norm (bool, optional): Use FiLM-style conditioning in ResBlocks. Defaults to False.
+ resblock_updown (bool, optional): Use ResBlocks for up/downsampling instead of conv layers. Defaults to False.
+ use_new_attention_order (bool, optional): Use optimized QKV attention implementation. Defaults to False.
+
+ Examples:
+ >>> import ppsci
+ >>> import paddle
+ >>> model = ppsci.arch.UNetModel(
+ ... image_size=64,
+ ... in_channels=3,
+ ... model_channels=128,
+ ... out_channels=3,
+ ... num_res_blocks=2,
+ ... attention_resolutions=[8, 16],
+ ... channel_mult=(1, 2, 4, 8),
+ ... num_heads=4,
+ ... )
+ >>> x = paddle.randn([4, 3, 64, 64])
+ >>> t = paddle.randint(0, 1000, [4])
+ >>> out = model(x, t)
+ >>> print(out.shape)
+ [4, 3, 64, 64]
+ """
+
+ def __init__(
+ self,
+ image_size,
+ in_channels,
+ model_channels,
+ out_channels,
+ num_res_blocks,
+ attention_resolutions,
+ dropout=0,
+ channel_mult=(1, 2, 4, 8),
+ conv_resample=True,
+ dims=2,
+ num_classes=None,
+ use_checkpoint=False,
+ use_fp16=False,
+ num_heads=1,
+ num_head_channels=-1,
+ num_heads_upsample=-1,
+ use_scale_shift_norm=False,
+ resblock_updown=False,
+ use_new_attention_order=False,
+ ):
+ super().__init__()
+ if num_heads_upsample == -1:
+ num_heads_upsample = num_heads
+ self.image_size = image_size
+ self.in_channels = in_channels
+ self.model_channels = model_channels
+ self.out_channels = out_channels
+ self.num_res_blocks = num_res_blocks
+ self.attention_resolutions = attention_resolutions
+ self.dropout = dropout
+ self.channel_mult = channel_mult
+ self.conv_resample = conv_resample
+ self.num_classes = num_classes
+ self.use_checkpoint = use_checkpoint
+ self.dtype = "float16" if use_fp16 else "float32"
+ self.num_heads = num_heads
+ self.num_head_channels = num_head_channels
+ self.num_heads_upsample = num_heads_upsample
+ time_embed_dim = model_channels * 4
+ self.time_embed = paddle.nn.Sequential(
+ linear(model_channels, time_embed_dim),
+ paddle.nn.Silu(),
+ linear(time_embed_dim, time_embed_dim),
+ )
+ if self.num_classes is not None:
+ self.label_emb = paddle.nn.Embedding(
+ num_embeddings=self.num_classes, embedding_dim=time_embed_dim
+ )
+ ch = input_ch = int(channel_mult[0] * model_channels)
+ self.input_blocks = paddle.nn.LayerList(
+ sublayers=[
+ TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1))
+ ]
+ )
+ self._feature_size = ch
+ input_block_chans = [ch]
+ ds = 1
+ for level, mult in enumerate(channel_mult):
+ for _ in range(num_res_blocks):
+ layers = []
+ layers.append(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=int(mult * model_channels),
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ )
+ ch = int(mult * model_channels)
+ if ds in attention_resolutions:
+ layers.append(
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads,
+ num_head_channels=num_head_channels,
+ use_new_attention_order=use_new_attention_order,
+ )
+ )
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
+ self._feature_size += ch
+ input_block_chans.append(ch)
+ if level != len(channel_mult) - 1:
+ out_ch = ch
+ self.input_blocks.append(
+ TimestepEmbedSequential(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=out_ch,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ down=True,
+ )
+ if resblock_updown
+ else Downsample(
+ ch, conv_resample, dims=dims, out_channels=out_ch
+ )
+ )
+ )
+ ch = out_ch
+ input_block_chans.append(ch)
+ ds *= 2
+ self._feature_size += ch
+ self.middle_block = TimestepEmbedSequential(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ ),
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads,
+ num_head_channels=num_head_channels,
+ use_new_attention_order=use_new_attention_order,
+ ),
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ ),
+ )
+ self._feature_size += ch
+ self.output_blocks = paddle.nn.LayerList(sublayers=[])
+ for level, mult in list(enumerate(channel_mult))[::-1]:
+ for i in range(num_res_blocks + 1):
+ ich = input_block_chans.pop()
+ layers = []
+ layers.append(
+ ResBlock(
+ ch + ich,
+ time_embed_dim,
+ dropout,
+ out_channels=int(model_channels * mult),
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ )
+ ch = int(model_channels * mult)
+ if ds in attention_resolutions:
+ layers.append(
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads_upsample,
+ num_head_channels=num_head_channels,
+ use_new_attention_order=use_new_attention_order,
+ )
+ )
+ if level and i == num_res_blocks:
+ out_ch = ch
+ layers.append(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=out_ch,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ up=True,
+ )
+ if resblock_updown
+ else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
+ )
+ ds //= 2
+ self.output_blocks.append(TimestepEmbedSequential(*layers))
+ self._feature_size += ch
+ self.out = paddle.nn.Sequential(
+ normalization(ch),
+ paddle.nn.Silu(),
+ zero_module(conv_nd(dims, input_ch, out_channels, 3, padding=1)),
+ )
+
+ def convert_to_fp16(self):
+ """
+ Convert the torso of the model to float16.
+ """
+ self.input_blocks.apply(convert_module_to_f16)
+ self.middle_block.apply(convert_module_to_f16)
+ self.output_blocks.apply(convert_module_to_f16)
+
+ def convert_to_fp32(self):
+ """
+ Convert the torso of the model to float32.
+ """
+ self.input_blocks.apply(convert_module_to_f32)
+ self.middle_block.apply(convert_module_to_f32)
+ self.output_blocks.apply(convert_module_to_f32)
+
+ def forward(self, x, timesteps, y=None):
+ """
+ Apply the model to an input batch.
+ """
+ assert (y is not None) == (
+ self.num_classes is not None
+ ), "must specify y if and only if the model is class-conditional"
+ hs = []
+ emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
+ if self.num_classes is not None:
+ assert tuple(y.shape) == (tuple(x.shape)[0],)
+ emb = emb + self.label_emb(y)
+ h = x.astype(self.dtype)
+ for module in self.input_blocks:
+ h = module(h, emb)
+ hs.append(h)
+ h = self.middle_block(h, emb)
+ for module in self.output_blocks:
+ h = paddle.concat(x=[h, hs.pop()], axis=1)
+ h = module(h, emb)
+ h = h.astype(x.dtype)
+ return self.out(h)
\ No newline at end of file