From 716572ee2bf35d5bcd8248d040783631bca3498f Mon Sep 17 00:00:00 2001
From: Ryan Mueller
Date: Mon, 10 Nov 2025 19:12:39 -0500
Subject: [PATCH] Added variable output dim to the conditioned channel
 attention module

---
 src/Restorer/Cond_NAF.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/src/Restorer/Cond_NAF.py b/src/Restorer/Cond_NAF.py
index 0bb6be5..d254e10 100644
--- a/src/Restorer/Cond_NAF.py
+++ b/src/Restorer/Cond_NAF.py
@@ -50,10 +50,12 @@ def forward(self, x):
 
 
 class ConditionedChannelAttention(nn.Module):
-    def __init__(self, dims, cat_dims):
+    def __init__(self, dims, cat_dims, out_dim=0):
         super().__init__()
         in_dim = dims + cat_dims
-        self.mlp = nn.Sequential(nn.Linear(in_dim, dims))
+        if not out_dim:
+            out_dim = dims
+        self.mlp = nn.Sequential(nn.Linear(in_dim, out_dim))
         self.pool = nn.AdaptiveAvgPool2d(1)
 
     def forward(self, x, conditioning):
@@ -95,7 +97,7 @@ def forward(self, x):
 class CondFuser(nn.Module):
     def __init__(self, chan, cond_chan=1):
         super().__init__()
-        self.cca = ConditionedChannelAttention(chan * 2, cond_chan)
+        self.cca = ConditionedChannelAttention(chan * 2, cond_chan, out_dim=chan)
         self.sig = nn.Sigmoid()
 
         self.sa = nn.Sequential(
@@ -107,7 +109,7 @@ def __init__(self, chan, cond_chan=1):
 
     def forward(self, x1, x2, cond):
         x = torch.cat([x1, x2], dim=1)
-        x2 = 1 * self.sig(self.cca(x)) * self.sa(x) * x2
+        x2 = 1 * self.sig(self.cca(x, cond)) * self.sa(x) * x2
         return x1 + x2
 
 