We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 372db01 commit b672453Copy full SHA for b672453
1 file changed
tests/test_losses.py
@@ -70,7 +70,8 @@ def test_output_is_scalar(self):
70
71
def test_identical_distributions_low_loss(self):
72
loss_fn = BhattacharyyaLoss()
73
- signal = torch.abs(torch.randn(2, 1, 50)) + 0.1
+ # Need T > 1 for meaningful bin-level normalization across tasks
74
+ signal = torch.abs(torch.randn(2, 4, 50)) + 0.1
75
loss = loss_fn(signal, signal.clone())
76
assert loss.item() < 0.01
77
0 commit comments