-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path_device.py
More file actions
27 lines (22 loc) · 746 Bytes
/
_device.py
File metadata and controls
27 lines (22 loc) · 746 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import os
from enum import Enum
from .device_id import DeviceId
class DeviceException(Exception):
    """Exception type for device-related errors.

    Adds nothing on top of ``Exception``; nothing in this module raises it.
    """
class _Device:
    """Remember the selected device and mirror it into ``CUDA_VISIBLE_DEVICES``."""

    def __init__(self):
        # Start on the CPU. Going through set() also writes an empty
        # CUDA_VISIBLE_DEVICES into the process environment.
        self.set(DeviceId.CPU)

    def is_gpu(self):
        """Return `True` if the current device is GPU, `False` otherwise."""
        active = self.current()
        return active is not DeviceId.CPU

    def current(self):
        """Return the device most recently stored by :meth:`set`."""
        return self._current_device

    def set(self, device: DeviceId):
        """Select ``device`` and return it.

        For ``DeviceId.CPU`` the ``CUDA_VISIBLE_DEVICES`` environment
        variable is set to the empty string. For any other value it is set
        to ``str(device.value)``, after which ``torch`` is imported and
        ``torch.backends.cudnn.benchmark`` is turned off.

        NOTE(review): the environment variable is written before the local
        ``import torch``; presumably this must happen before CUDA is
        initialised anywhere in the process for it to take effect - confirm
        against callers.
        """
        on_cpu = device == DeviceId.CPU
        visible_devices = '' if on_cpu else str(device.value)
        os.environ['CUDA_VISIBLE_DEVICES'] = visible_devices

        if not on_cpu:
            # Imported lazily so the CPU path never pulls in torch from here.
            import torch
            torch.backends.cudnn.benchmark = False

        self._current_device = device
        return device