-
Notifications
You must be signed in to change notification settings - Fork 6
Expand file tree
/
Copy pathdatasets.py
More file actions
79 lines (67 loc) · 2.65 KB
/
datasets.py
File metadata and controls
79 lines (67 loc) · 2.65 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import torch
import json
import os
from torch.utils.data import Dataset
import boto3
from botocore import UNSIGNED
from botocore.client import Config
import s3fs
# Location of the challenge data in S3; the bucket is read without credentials.
BUCKET = "chimera-challenge"
PREFIX = "v2/task3/data/"

# Torch device string: use the GPU when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Two unauthenticated S3 handles, created at import time:
# - `s3`: a low-level boto3 client with unsigned requests (not referenced by
#   the dataset classes in this file).
# - `fs`: an anonymous s3fs filesystem; the dataset classes open their JSON
#   files through it.
s3 = boto3.client("s3", config=Config(signature_version=UNSIGNED))
fs = s3fs.S3FileSystem(anon=True)
class BaseMedicalDataset(Dataset):
    """Map-style dataset base class where each sample is one JSON file.

    Subclasses implement ``__getitem__`` and use :meth:`_load_json` to read
    the raw record for an index.

    Args:
        file_paths: Sized, indexable collection of paths, one JSON file per
            sample. Each path is passed unchanged to the filesystem's
            ``open``; with the default module-level ``fs`` (an anonymous
            s3fs filesystem) these are S3 paths.
        filesystem: Optional object exposing an fsspec-style
            ``open(path, mode)`` method. When ``None`` (the default) the
            module-level ``fs`` is used, so existing callers are unaffected.
            Supplying another filesystem allows reading local or mirrored
            copies of the data.
    """

    def __init__(self, file_paths, filesystem=None):
        self.file_paths = file_paths
        self.filesystem = filesystem

    def __len__(self):
        """Return the number of samples (one per entry in ``file_paths``)."""
        return len(self.file_paths)

    def _load_json(self, idx):
        """Read and parse the JSON file at position ``idx`` of ``file_paths``.

        The file is opened in binary mode and its bytes are handed to
        ``json.load``, which detects UTF-8/16/32 on its own. Opening in text
        mode without an explicit encoding would decode with the platform's
        locale encoding, which can garble or reject non-ASCII content.

        Raises:
            IndexError: if ``idx`` is out of range for ``file_paths``.
            json.JSONDecodeError: if the file is not valid JSON.
        """
        # `fs` is only looked up when no filesystem was injected.
        opener = fs if self.filesystem is None else self.filesystem
        with opener.open(self.file_paths[idx], "rb") as f:
            return json.load(f)
class ClinicalDataset(BaseMedicalDataset):
    """Dataloader for Clinical data (Dimension: 13).

    Each JSON record is encoded by :meth:`_process_features` into a float32
    tensor whose 13 entries follow ``FEATURE_NAMES``.
    """

    # Categorical encodings. They are built once at class level rather than
    # on every sample; values not found in a map are encoded as 0.
    _SEX_MAP = {"Male": 0, "Female": 1}
    _TUMOR_MAP = {"Primary": 0, "Recurrent": 1}
    _YES_NO_MAP = {"No": 0, "Yes": 1}  # shared by the reTUR and LVI fields
    _GRADE_MAP = {"G1": 1, "G2": 2, "G3": 3}

    # Position-by-position meaning of the returned feature vector.
    FEATURE_NAMES = (
        "age",
        "sex",
        "smoking",
        "tumor",
        "stage_T1",
        "substage_T1e",
        "grade",
        "reTUR",
        "LVI",
        "variant_UCC",
        "no_instillations",
        "progression",
        "Time_to_prog_or_FUend",
    )

    def _process_features(self, data):
        """Encode one clinical record as a float32 tensor of shape (13,).

        Args:
            data: Parsed JSON record (a dict). Every key read below must be
                present; a missing key raises ``KeyError``.

        Returns:
            ``torch.Tensor`` (float32) ordered as ``FEATURE_NAMES``.
        """
        features = [
            float(data['age']),
            float(self._SEX_MAP.get(data['sex'], 0)),
            # Anything other than exactly "No" is encoded as 1.0.
            0.0 if data['smoking'] == 'No' else 1.0,
            float(self._TUMOR_MAP.get(data['tumor'], 0)),
            # Substring test: any stage string containing "T1" gives 1.0.
            1.0 if "T1" in data['stage'] else 0.0,
            1.0 if data['substage'] == "T1e" else 0.0,
            float(self._GRADE_MAP.get(data['grade'], 0)),
            float(self._YES_NO_MAP.get(data['reTUR'], 0)),
            float(self._YES_NO_MAP.get(data['LVI'], 0)),
            1.0 if data['variant'] == "UCC" else 0.0,
            float(data['no_instillations']),
            # NOTE(review): judging by their names, the two fields below look
            # like outcome variables (progression event and time to
            # progression / follow-up end). If this dataset feeds a model
            # that predicts progression, using them as inputs would leak the
            # target — confirm with the training code.
            float(data['progression']),
            float(data['Time_to_prog_or_FUend']),
        ]
        return torch.tensor(features, dtype=torch.float32)

    def __getitem__(self, idx):
        """Load the JSON file at ``idx`` and return its encoded features."""
        data = self._load_json(idx)
        return self._process_features(data)
class RNADataset(BaseMedicalDataset):
    """Dataloader for RNA data (Dimension: 19359 expected for this data).

    Each sample's JSON record is read as a mapping from gene name to value.
    The returned vector has one entry per gene in ``gene_list``, in that
    order, so its length equals ``len(gene_list)``.

    Args:
        file_paths: One JSON file path per sample (see BaseMedicalDataset).
        gene_list: Optional iterable fixing the gene order. When ``None``,
            the order is the sorted key set of the first file. To keep
            columns aligned across splits, pass one dataset's ``gene_list``
            to the others.
    """

    def __init__(self, file_paths, gene_list=None):
        super().__init__(file_paths)
        if gene_list is not None:
            # Copy so that a one-shot iterator is not exhausted after the
            # first sample, and later changes to the caller's sequence cannot
            # silently reorder the columns.
            self.gene_list = list(gene_list)
        elif len(self) > 0:
            # Infer the column order from the first sample. Sorting makes it
            # independent of key order inside the JSON file.
            first_sample = self._load_json(0)
            self.gene_list = sorted(first_sample.keys())
        else:
            # No files to infer from (previously an IndexError): an empty
            # dataset has no columns.
            self.gene_list = []

    def __getitem__(self, idx):
        """Return the values for sample ``idx`` as a float32 tensor.

        Genes in ``gene_list`` that are absent from the file are filled with
        0.0. Keys in the file that are not in ``gene_list`` are ignored.
        """
        data = self._load_json(idx)
        rna_values = [data.get(gene, 0.0) for gene in self.gene_list]
        return torch.tensor(rna_values, dtype=torch.float32)