Description
Hi,
Thank you for the wonderful repository. I truly appreciate your implementation of the PMSN loss, which even the original author did not provide.
As you can see in the paper, the MSN loss function is defined as follows.
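Written out (my paraphrase of the paper's objective, so the notation is mine): with $B$ anchor views, anchor predictions $p_i$ and sharpened target assignments $t_i$ over $K$ prototypes,

$$
\mathcal{L} \;=\; \frac{1}{B}\sum_{i=1}^{B} H(t_i,\, p_i) \;-\; \lambda\, H(\bar{p}),
\qquad \bar{p} \;=\; \frac{1}{B}\sum_{i=1}^{B} p_i,
$$

where $H(t_i, p_i)$ is the cross-entropy between targets and predictions, and $H(\bar{p})$ is the entropy of the mean prediction (the me-max regularizer, weighted by $\lambda$).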
It consists of two components, each implemented as follows:
```python
# cross-entropy between anchor predictions and the sharpened targets
loss = torch.mean(torch.sum(torch.log(probs**(-targets)), dim=1))

# Step 4: compute me-max regularizer
rloss = 0.
if me_max:
    avg_probs = AllReduce.apply(torch.mean(probs, dim=0))
    rloss = - torch.sum(torch.log(avg_probs**(-avg_probs))) + math.log(float(len(avg_probs)))
```

However, the author/implementer also added the following term:

```python
sloss = 0.
if use_entropy:
    sloss = torch.mean(torch.sum(torch.log(probs**(-probs)), dim=1))
```

This additional term is not mentioned anywhere in the paper. However, it is actively used in their configuration file (msn_vits16.yaml), where it is set to true and included in the loss function.
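For reference, here is my reading of what these two terms compute, derived directly from the code above (the notation is mine, not the authors'):

$$
\text{rloss} \;=\; \log K \;-\; H(\bar{p}),
\qquad
\text{sloss} \;=\; \frac{1}{B}\sum_{i=1}^{B} H(p_i),
$$

where $H$ is the Shannon entropy, $\bar{p}$ is the batch-mean prediction, and $K$ is the number of prototypes. So rloss pushes the entropy of the average prediction up (me-max), while sloss is the average entropy of the individual anchor predictions.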
In your implementation of MSNLoss and PMSNLoss (in msn_loss.py and pmsn_loss.py), this sloss term does not appear; it is ignored entirely. I would like to understand why it was omitted and what the reasoning behind this decision was.
Do you think incorporating it could have improved the final results?
Finally, my main question: if we want to follow the approach taken by the author of PMSN (who unfortunately does not respond to emails), what is the correct choice? Should we simply replace the rloss term with the KL term you provide and drop sloss, or should we keep sloss as well?
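To make the question concrete, here is a rough, self-contained sketch of the second option (PMSN's KL-to-power-law-prior regularizer plus the original sloss term). Everything in it is my own assumption for illustration: the function name, the weights, the power-law exponent, and in particular the sign with which sloss enters the total are not taken from either repository.

```python
import torch
import torch.nn.functional as F


def pmsn_style_loss(probs, targets, power_law_exponent=0.25,
                    reg_weight=1.0, ent_weight=1.0, keep_sloss=True):
    """Hypothetical combination: cross-entropy + PMSN KL regularizer (+ sloss)."""
    # cross-entropy between sharpened targets and anchor predictions
    ce = torch.mean(torch.sum(-targets * torch.log(probs), dim=1))

    # PMSN-style regularizer: KL(mean prediction || power-law prior over prototypes)
    avg_probs = torch.mean(probs, dim=0)
    ranks = torch.arange(1, avg_probs.numel() + 1, dtype=probs.dtype, device=probs.device)
    prior = ranks ** (-power_law_exponent)
    prior = prior / prior.sum()
    kl = torch.sum(avg_probs * (torch.log(avg_probs) - torch.log(prior)))

    # optional extra entropy term from the original MSN code ("sloss");
    # whether and how it should enter the total is exactly the open
    # question, so treating it as an additive term here is an assumption
    sloss = torch.mean(torch.sum(-probs * torch.log(probs), dim=1)) if keep_sloss else 0.0

    return ce + reg_weight * kl + ent_weight * sloss


# toy usage with random, softmax-normalized tensors
probs = F.softmax(torch.randn(8, 1024), dim=1)
targets = F.softmax(torch.randn(8, 1024), dim=1)
print(pmsn_style_loss(probs, targets))
```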
Looking forward to your insights.