-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathdatasets.py
More file actions
58 lines (44 loc) · 1.83 KB
/
datasets.py
File metadata and controls
58 lines (44 loc) · 1.83 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
'''
Copyright (c) 2023 University of Southern California
See full notice in LICENSE.md
Hamidreza Abbaspourazad*, Eray Erturk* and Maryam M. Shanechi
Shanechi Lab, University of Southern California
'''
from torch.utils.data import Dataset
import torch
class DFINEDataset(Dataset):
    '''
    Dataset class for DFINE.

    Each item is a tuple (y, behv, mask) for a single sequence, selected along the
    first (num_seq) dimension of the stored tensors.
    '''
    def __init__(self, y, behv=None, mask=None):
        '''
        Initializer for DFINEDataset. Note that this is a subclass of torch.utils.data.Dataset.

        Parameters:
        ------------
        - y: torch.Tensor, shape: (num_seq, num_steps, dim_y), High dimensional neural observations.
        - behv: torch.Tensor, shape: (num_seq, num_steps, dim_behv), Behavior data. None by default,
                in which case a float32 tensor of zeros with shape (num_seq, num_steps, 1) is used.
        - mask: torch.Tensor, shape: (num_seq, num_steps, 1), Mask for manifold latent factors which shows
                whether observations at each timestep exist (1) or are missing (0). None by default, in
                which case a float32 tensor of ones with shape (num_seq, num_steps, 1) is used
                (every timestep treated as observed).
        '''
        self.y = y

        # Shape shared by the default behv / mask placeholders: y's shape with the last axis replaced by 1.
        default_shape = (*y.shape[:-1], 1)
        # Create the placeholders on y's device so that a sample never mixes devices.
        # For non-tensor inputs, device=None keeps torch's default device (the previous behavior).
        device = y.device if isinstance(y, torch.Tensor) else None

        # If behv is not provided, initialize it by zeros.
        if behv is None:
            self.behv = torch.zeros(default_shape, dtype=torch.float32, device=device)
        else:
            self.behv = behv

        # If mask is not provided, initialize it by ones (all timesteps observed).
        if mask is None:
            self.mask = torch.ones(default_shape, dtype=torch.float32, device=device)
        else:
            self.mask = mask

    def __len__(self):
        '''
        Returns the length of the dataset, i.e. the size of the first dimension (num_seq) of y.
        '''
        return self.y.shape[0]

    def __getitem__(self, idx):
        '''
        Returns a tuple of neural observations, behavior and mask segments for sequence(s) idx.
        '''
        return self.y[idx, :, :], self.behv[idx, :, :], self.mask[idx, :, :]