We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 786f593 commit 58f4227Copy full SHA for 58f4227
bin/predict.py
@@ -41,7 +41,7 @@ def main(predict_config: OmegaConf):
41
if sys.platform != 'win32':
42
register_debug_signal_handlers() # kill -10 <pid> will result in traceback dumped into log
43
44
- device = torch.device("cpu")
+ device = torch.device(predict_config.get('device', 'cpu'))
45
46
train_config_path = os.path.join(predict_config.model.path, 'config.yaml')
47
with open(train_config_path, 'r') as f:
0 commit comments