Skip to content

Commit b672453

Browse files
committed
test normalization bug fix
1 parent 372db01 commit b672453

1 file changed

Lines changed: 2 additions & 1 deletion

File tree

tests/test_losses.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,8 @@ def test_output_is_scalar(self):
7070

7171
def test_identical_distributions_low_loss(self):
7272
loss_fn = BhattacharyyaLoss()
73-
signal = torch.abs(torch.randn(2, 1, 50)) + 0.1
73+
# Need T > 1 for meaningful bin-level normalization across tasks
74+
signal = torch.abs(torch.randn(2, 4, 50)) + 0.1
7475
loss = loss_fn(signal, signal.clone())
7576
assert loss.item() < 0.01
7677

0 commit comments

Comments
 (0)