Skip to content

Commit 58f4227

Browse files
authored
feat: use device from config file when predict
1 parent 786f593 commit 58f4227

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

bin/predict.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def main(predict_config: OmegaConf):
4141
if sys.platform != 'win32':
4242
register_debug_signal_handlers() # kill -10 <pid> will result in traceback dumped into log
4343

44-
device = torch.device("cpu")
44+
device = torch.device(predict_config.get('device', 'cpu'))
4545

4646
train_config_path = os.path.join(predict_config.model.path, 'config.yaml')
4747
with open(train_config_path, 'r') as f:

0 commit comments

Comments
 (0)