-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathmetrics.py
More file actions
64 lines (47 loc) · 1.76 KB
/
metrics.py
File metadata and controls
64 lines (47 loc) · 1.76 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
'''
Copyright (c) 2023 University of Southern California
See full notice in LICENSE.md
Hamidreza Abbaspourazad*, Eray Erturk* and Maryam M. Shanechi
Shanechi Lab, University of Southern California
'''
from torchmetrics import Metric
import torch
class Mean(Metric):
    '''
    Mean metric class to log batch-averaged metrics to Tensorboard.

    Keeps a running sum of per-batch averages weighted by their batch sizes,
    so compute() returns the average over every sample seen since the last
    reset (rather than an unweighted average of batch averages).
    '''

    # update() only adds into the running totals, so torchmetrics may merge a
    # batch's local state into the global state with the registered "sum"
    # reduction instead of re-running update() against the full state.
    full_state_update = False

    def __init__(self):
        '''
        Initializer for Mean metric. Note that this class is a subclass of torchmetrics.Metric.
        '''
        super().__init__(dist_sync_on_step=False)
        # Define total sum and number of samples that sum is computed over.
        # Both are registered with dist_reduce_fx="sum", so they are added
        # together when states are synchronized across processes.
        self.add_state("sum", default=torch.tensor(0, dtype=torch.float32), dist_reduce_fx="sum")
        self.add_state("num_samples", default=torch.tensor(0, dtype=torch.float32), dist_reduce_fx="sum")

    def update(self, value, batch_size):
        '''
        Updates the total sum and number of samples
        Parameters:
        ------------
        - value: torch.Tensor or Python number, shape: (), Value to add to sum
                 (an average over 'batch_size' samples)
        - batch_size: torch.Tensor or Python number, shape: (), Number of samples that 'value' is averaged over
        '''
        # as_tensor accepts Python numbers as well as tensors (torch.tensor on a
        # tensor input copies it and emits a UserWarning). Detaching keeps the
        # accumulated state from holding on to an autograd graph, and moving to
        # the state's device keeps the metric usable after .to(device).
        value = torch.as_tensor(value).detach().to(device=self.sum.device, dtype=torch.float32)
        batch_size = torch.as_tensor(batch_size).detach().to(device=self.sum.device, dtype=torch.float32)
        # Weight the batch average by its batch size so compute() yields a
        # per-sample mean.
        self.sum += value * batch_size
        self.num_samples += batch_size

    def reset(self):
        '''
        Resets the total sum and number of samples to 0

        Delegates to torchmetrics.Metric.reset, which restores each state to the
        default registered in add_state (0 here) on the state's current device
        and also clears the metric's internal bookkeeping, including any cached
        compute() result. Re-assigning the states by hand would skip that
        bookkeeping, so compute() right after a reset could return the stale
        pre-reset value.
        '''
        super().reset()

    def compute(self):
        '''
        Computes the mean metric.
        Returns:
        ------------
        - avg: torch.Tensor, shape: (), Average value for the metric, i.e. sum / num_samples.
               This is NaN (0 / 0) if no samples have been accumulated since the last reset.
        '''
        avg = self.sum / self.num_samples
        return avg