Skip to content

Commit 0b46ab8

Browse files
committed
changed weights to list to consume less memory
1 parent febc940 commit 0b46ab8

File tree

4 files changed

+36
-14
lines changed

4 files changed

+36
-14
lines changed

chebai/models/base.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010

1111
import extras.adamh as f
1212

13+
import extras.weight_loader as e
14+
1315
logging.getLogger("pysmiles").setLevel(logging.CRITICAL)
1416

1517
_MODEL_REGISTRY = dict()
@@ -271,6 +273,7 @@ def _execute(
271273
loss_kwargs = loss_kwargs_candidates
272274
#torch.save(loss_data,"loss_data.pt")
273275
loss_kwargs['weights'] = f.create_data_weights(batchsize=len(data['idents']),dim=data['labels'].size(dim=1),weights=data["loss_kwargs"],idents=data["idents"])
276+
#loss_kwargs['weights'] = e.create_data_weights(batchsize=len(data['idents']),dim=data['labels'].size(dim=1),weights=data["loss_kwargs"],idents=data["idents"])
274277

275278
loss_kwargs["current_epoch"] = self.trainer.current_epoch
276279
loss = self.criterion(loss_data, loss_labels, **loss_kwargs)

chebai/preprocessing/datasets/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1182,7 +1182,7 @@ def load_processed_data(
11821182
data_df = self.dynamic_split_dfs[kind]
11831183
data = data_df.to_dict(orient="records")
11841184
if kind == "train" :
1185-
#f.init_weights()
1185+
# f.init_weights()
11861186
data = f.add_train_weights(data)
11871187
if kind == "validation" :
11881188
data = f.add_val_weights(data)

extras/adamh.py

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
import torch
22
import csv
3+
import numpy
34

45

56
train = 0
67

78

8-
def create_weight(path_to_split="../../split/splits.csv"):
9+
def create_weight(path_to_split="/home/programmer/Bachelorarbeit/split/splits.csv"):
910
weights = {}
1011
with open(path_to_split, 'r') as csvfile:
1112
reader = csv.reader(csvfile)
@@ -17,13 +18,27 @@ def create_weight(path_to_split="../../split/splits.csv"):
1718
#print(row[0])
1819
i = i +1
1920
print(len(weights))
20-
torch.save(weights,"../../weights/init_mh.pt")
21+
torch.save(weights,"/home/programmer/Bachelorarbeit/weights/init_mh.pt")
22+
23+
24+
25+
def new_create_weight(path_to_split="/home/programmer/Bachelorarbeit/split/splits.csv"):
26+
weights = {}
27+
with open(path_to_split, 'r') as csvfile:
28+
reader = csv.reader(csvfile)
29+
i = 0
30+
for row in reader:
31+
if (row[1] == "train") and i > 0:
32+
# print(row[0])
33+
weights[row[0]] = [int(row[0])]* 1528
34+
# print(row[0])
35+
i = i + 1
36+
print(len(weights))
37+
torch.save(weights, "../../weights/init_mh.pt")
38+
2139

2240
def add_train_weights(ids):
2341
d = torch.load("/home/programmer/Bachelorarbeit/weights/init_mh.pt",weights_only=False)
24-
global train
25-
train = train + 1
26-
print(train)
2742
it = 0
2843
for i in ids:
2944
if it % 10000 == 0:
@@ -36,19 +51,23 @@ def add_train_weights(ids):
3651
def add_val_weights(ids):
3752
for i in ids:
3853
weight = 1
39-
i["weight"] = torch.full((1,1528),1)
54+
#i["weight"] = torch.full((1,1528),1)
55+
i["weight"] = [1]*1528
56+
4057
return ids
4158

42-
def create_data_weights(batchsize:int,dim:int,weights:dict[str,torch.Tensor],idents:tuple[int,...])-> torch.tensor:
59+
def create_data_weights(batchsize:int,dim:int,weights:dict[str,list[float,...]],idents:tuple[int,...])-> torch.tensor:
4360
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
4461
weight = None
4562
index = 0
4663
for i in idents:
47-
w = weights[str(i)]
64+
w = torch.Tensor([weights[str(i)],]).to(device)
4865
if weight == None:
4966
weight = w
5067
else:
5168
weight = torch.cat((weight,w),0)
5269
index = index + 1
5370
return weight
5471

72+
#new_create_weight()
73+
#create_weight()

extras/weight_loader.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -110,9 +110,9 @@ def add_train_weights(ids):
110110
it = it +1
111111
return ids
112112

113-
def check_weights(data):
114-
for i in data:
115-
print(f"({i["ident"]} , {i["weight"]}")
113+
#def check_weights(data):
114+
# for i in data:
115+
# print(f"({i["ident"]} , {i["weight"]}")
116116

117117

118118
def init_class_weights(class_path:str,weight_path:str,weight:float):
@@ -154,5 +154,5 @@ def create_weight_class_tensor(batch_size:int)-> torch.Tensor:
154154

155155

156156
#init_class_weights("../../data/chebi_v241/ChEBI50/processed/classes.txt","../../weights/class_first_it.csv",1)
157-
create_class_tensor("../../weights/test.pt")
158-
create_weight_class_tensor(32)
157+
#create_class_tensor("../../weights/test.pt")
158+
#create_weight_class_tensor(32)

0 commit comments

Comments
 (0)