From 7c58c6e6ed5d8415318204ff85607c05dd1c8c67 Mon Sep 17 00:00:00 2001 From: Ryan Mueller Date: Mon, 10 Nov 2025 19:03:32 -0500 Subject: [PATCH] Replacing CondFuser --- src/Restorer/Cond_NAF.py | 53 ++++++++-------------------------------- 1 file changed, 10 insertions(+), 43 deletions(-) diff --git a/src/Restorer/Cond_NAF.py b/src/Restorer/Cond_NAF.py index b87700a..0bb6be5 100644 --- a/src/Restorer/Cond_NAF.py +++ b/src/Restorer/Cond_NAF.py @@ -65,47 +65,6 @@ def forward(self, x, conditioning): return ca -# class CondFuser(nn.Module): -# def __init__(self, chan, cond_chan=1): -# super().__init__() -# self.cca = ConditionedChannelAttention(chan * 2, cond_chan) -# # self.spa = nn.Conv2d( -# # in_channels=chan * 2, -# # out_channels=1, -# # kernel_size=3, -# # padding=1, -# # stride=1, -# # groups=1, -# # bias=True, -# # ) - -# def forward(self, x1, x2, cond): -# x = torch.cat([x1, x2], dim=1) -# x = self.cca(x, cond) * x -# # spa = torch.sigmoid(self.spa(x)) - -# x1, x2 = x.chunk(2, dim=1) -# # return x1 * spa + x2 * (1 - spa) -# return x1 + x2 - - -class CondFuser(nn.Module): - def __init__(self, chan, cond_chan=1): - super().__init__() - self.cca = ConditionedChannelAttention(chan * 2, cond_chan) - self.sig = nn.Sigmoid() - - self.sa = nn.Sequential( - nn.Conv2d(in_channels = 2 * chan, out_channels=chan, kernel_size=3, - padding=1, stride=1, - groups=1, bias=True), - nn.Sigmoid() - ) - - def forward(self, x1, x2, cond): - x = torch.cat([x1, x2], dim=1) - x2 = 1 * self.sig(self.cca(x)) * self.sa(x) * x2 - return x1 + x2 class NKA(nn.Module): def __init__(self, dim, channel_reduction = 8): @@ -137,12 +96,20 @@ class CondFuser(nn.Module): def __init__(self, chan, cond_chan=1): super().__init__() self.cca = ConditionedChannelAttention(chan * 2, cond_chan) + self.sig = nn.Sigmoid() + self.sa = nn.Sequential( + nn.Conv2d(in_channels = 2 * chan, out_channels=chan, kernel_size=3, + padding=1, stride=1, + groups=1, bias=True), + nn.Sigmoid() + ) + def forward(self, x1, x2, cond): x = torch.cat([x1, x2], dim=1) - x = self.cca(x, cond) * x - x1, x2 = x.chunk(2, dim=1) + x2 = 1 * self.sig(self.cca(x)) * self.sa(x) * x2 return x1 + x2 + class CondFuserAdd(nn.Module): def __init__(self, chan, cond_chan=1):