diff --git a/chebifier/prediction_models/nn_predictor.py b/chebifier/prediction_models/nn_predictor.py index e7d72c9..f5c41b7 100644 --- a/chebifier/prediction_models/nn_predictor.py +++ b/chebifier/prediction_models/nn_predictor.py @@ -22,6 +22,8 @@ def __init__( self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.model = self.init_model(ckpt_path=ckpt_path) + self.model.eval() + self.target_labels = [ line.strip() for line in open(target_labels_path, encoding="utf-8") ] @@ -32,6 +34,7 @@ def init_model(self, ckpt_path: str, **kwargs): "Model initialization must be implemented in subclasses." ) + @torch.inference_mode() def calculate_results(self, batch): collator = self.reader_cls.COLLATOR() dat = self.model._process_batch(collator(batch).to(self.device), 0)