From 3d7d91855bd009fd26c7e46e94d67b542ef8a958 Mon Sep 17 00:00:00 2001
From: aditya0by0
Date: Sat, 22 Nov 2025 21:04:55 +0100
Subject: [PATCH] avoid gradient tracking for faster inference and less memory consumption

---
 chebifier/prediction_models/nn_predictor.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/chebifier/prediction_models/nn_predictor.py b/chebifier/prediction_models/nn_predictor.py
index e7d72c9..f5c41b7 100644
--- a/chebifier/prediction_models/nn_predictor.py
+++ b/chebifier/prediction_models/nn_predictor.py
@@ -22,6 +22,8 @@ def __init__(
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         self.model = self.init_model(ckpt_path=ckpt_path)
+        self.model.eval()
+
         self.target_labels = [
             line.strip() for line in open(target_labels_path, encoding="utf-8")
         ]
@@ -32,6 +34,7 @@ def init_model(self, ckpt_path: str, **kwargs):
             "Model initialization must be implemented in subclasses."
         )
 
+    @torch.inference_mode()
     def calculate_results(self, batch):
         collator = self.reader_cls.COLLATOR()
         dat = self.model._process_batch(collator(batch).to(self.device), 0)
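
For context, the two additions address different things: model.eval() switches modules such as Dropout and BatchNorm to their inference behaviour, while torch.inference_mode() turns off autograd tracking so no computation graph or intermediate activations are kept for a backward pass. A minimal standalone sketch of the pattern, using a toy nn.Sequential model rather than the chebifier predictor:

    import torch
    import torch.nn as nn

    # Toy model standing in for the real predictor (assumption: any nn.Module behaves the same way here).
    model = nn.Sequential(nn.Linear(8, 8), nn.Dropout(p=0.5), nn.Linear(8, 2))

    # eval() makes Dropout a no-op and freezes BatchNorm statistics;
    # it does NOT by itself stop gradient tracking.
    model.eval()

    x = torch.randn(4, 8)

    # inference_mode() disables autograd entirely: no graph is built,
    # so memory use drops and the forward pass is faster.
    with torch.inference_mode():
        out = model(x)

    print(out.requires_grad)  # False: no gradient bookkeeping on the output

Decorating a method with @torch.inference_mode(), as the patch does for calculate_results, is equivalent to wrapping the whole method body in this context manager.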