@@ -3,6 +3,7 @@
 import torch
 from commode_utils.losses import SequenceCrossEntropyLoss
 from commode_utils.metrics import SequentialF1Score, ClassificationMetrics
+from commode_utils.metrics.chrF import ChrF
 from commode_utils.modules import LSTMDecoderStep, Decoder
 from omegaconf import DictConfig
 from pytorch_lightning import LightningModule
@@ -41,6 +42,10 @@ def __init__( |
             f"{holdout}_f1": SequentialF1Score(pad_idx=self.__pad_idx, eos_idx=eos_idx, ignore_idx=ignore_idx)
             for holdout in ["train", "val", "test"]
         }
+        id2label = {v: k for k, v in vocabulary.label_to_id.items()}
+        metrics.update(
+            {f"{holdout}_chrf": ChrF(id2label, ignore_idx + [self.__pad_idx, eos_idx]) for holdout in ["val", "test"]}
+        )
         self.__metrics = MetricCollection(metrics)

         self._encoder = self._get_encoder(model_config)
@@ -102,18 +107,18 @@ def _shared_step(self, batch: BatchedLabeledPathContext, step: str) -> Dict: |
         target_sequence = batch.labels if step == "train" else None
         # [seq length; batch size; vocab size]
         logits, _ = self.logits_from_batch(batch, target_sequence)
-        loss = self.__loss(logits[1:], batch.labels[1:])
+        result = {f"{step}/loss": self.__loss(logits[1:], batch.labels[1:])}

         with torch.no_grad():
             prediction = logits.argmax(-1)
             metric: ClassificationMetrics = self.__metrics[f"{step}_f1"](prediction, batch.labels)
+            result.update(
+                {f"{step}/f1": metric.f1_score, f"{step}/precision": metric.precision, f"{step}/recall": metric.recall}
+            )
+            if step != "train":
+                result[f"{step}/chrf"] = self.__metrics[f"{step}_chrf"](prediction, batch.labels)

-        return {
-            f"{step}/loss": loss,
-            f"{step}/f1": metric.f1_score,
-            f"{step}/precision": metric.precision,
-            f"{step}/recall": metric.recall,
-        }
+        return result

     def training_step(self, batch: BatchedLabeledPathContext, batch_idx: int) -> Dict:  # type: ignore
         result = self._shared_step(batch, "train")
@@ -143,6 +148,9 @@ def _shared_epoch_end(self, step_outputs: EPOCH_OUTPUT, step: str): |
             f"{step}/recall": metric.recall,
         }
         self.__metrics[f"{step}_f1"].reset()
+        if step != "train":
+            log[f"{step}/chrf"] = self.__metrics[f"{step}_chrf"].compute()
+            self.__metrics[f"{step}_chrf"].reset()
         self.log_dict(log, on_step=False, on_epoch=True)

     def training_epoch_end(self, step_outputs: EPOCH_OUTPUT):
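
For reference, a minimal standalone sketch of how the added ChrF metric is exercised, assuming the constructor and call signature used in the diff above; the toy id2label mapping, index values, and tensors here are made up for illustration only:

```python
import torch
from commode_utils.metrics.chrF import ChrF

# Toy id -> label mapping; the real code builds it from vocabulary.label_to_id.
id2label = {0: "<PAD>", 1: "<EOS>", 2: "get", 3: "name", 4: "value"}
pad_idx, eos_idx = 0, 1

# Ids listed here are skipped when decoding ids back to strings,
# mirroring `ignore_idx + [self.__pad_idx, eos_idx]` in __init__.
chrf = ChrF(id2label, [pad_idx, eos_idx])

# [seq length; batch size] tensors of label ids, the same shape as
# `logits.argmax(-1)` and `batch.labels` in `_shared_step`.
predictions = torch.tensor([[2, 2], [3, 4], [1, 1]])
targets = torch.tensor([[2, 2], [3, 3], [1, 1]])

batch_score = chrf(predictions, targets)  # per-batch value, as logged in val/test steps
epoch_score = chrf.compute()              # accumulated value, as logged in _shared_epoch_end
chrf.reset()                              # state is cleared once per epoch
```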