-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
93 lines (76 loc) · 2.31 KB
/
train.py
File metadata and controls
93 lines (76 loc) · 2.31 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import pickle
import os
import sys
import time
import matplotlib.pyplot as plt
import yaml
from scripts.training_loop import train
from models.PUMiNet import *
import torch
import torch.optim as optim
# Load run settings from config.yaml, resolved relative to the current
# working directory (not relative to this script's location).
with open('config.yaml', 'r', encoding='utf-8') as file:
    config = yaml.safe_load(file)
print(config)

# Fail fast with a readable message instead of a bare TypeError/KeyError
# when the YAML document is empty (safe_load yields None) or a required
# setting is missing.
_REQUIRED_KEYS = (
    "epochs",
    "in_sample",
    "out_path",
    "embed_dim",
    "num_heads",
    "learning_rate",
    "lr_step_size",
    "lr_step_gamma",
)
if not isinstance(config, dict):
    sys.exit("config.yaml must define a mapping of training settings")
_missing_keys = [key for key in _REQUIRED_KEYS if key not in config]
if _missing_keys:
    sys.exit("config.yaml is missing required keys: " + ", ".join(_missing_keys))

# Load configuration
Epochs = config["epochs"]
in_sample = config["in_sample"]
out_path = config["out_path"]
embed_dim = config["embed_dim"]
num_heads = config["num_heads"]
learning_rate = config["learning_rate"]
lr_step_size = config["lr_step_size"]
lr_step_gamma = config["lr_step_gamma"]
# Create a fresh output directory. An existing directory is refused so that
# a previous run's saved model and plots are never silently overwritten.
try:
    os.mkdir(out_path)
except FileExistsError as error:
    print(error)
    print("Please try cleaning out your workspace or use new path! :)")
    sys.exit(1)
except OSError as error:
    # Any other failure (e.g. missing parent directory, no permission):
    # the "clean out your workspace" advice would be misleading here.
    print(error)
    sys.exit(1)
# Read the pickled dataset named by `in_sample`. It must unpack into six
# parts: train/validation/test inputs and targets.
# NOTE(review): pickle.load can execute arbitrary code -- only point
# in_sample at files you produced yourself.
with open(in_sample, 'rb') as sample_file:
    data = pickle.load(sample_file)
(X_train, y_train,
 X_val, y_val,
 X_test, y_test) = data
# Loss functions: MSE for the jet output, binary cross-entropy for the track
# output. torch.nn is referenced explicitly instead of relying on `nn`
# leaking in through the star import of models.PUMiNet.
# NOTE(review): BCELoss requires trk_pred to already be probabilities in
# [0, 1] -- confirm PUMiNet applies a sigmoid to that head.
jet_loss_fn = torch.nn.MSELoss()
trk_loss_fn = torch.nn.BCELoss()

# Get Instance of the model
model = PUMiNet(embed_dim, num_heads)
print(model)
print()
print("Trainable Parameters :", sum(p.numel() for p in model.parameters() if p.requires_grad))
print()

# Smoke test: pass a single event through the (still untrained) model and
# print its losses. Each X_train event is indexed by the positions below.
Event_no = 0
Jets = 0
Trk_Jet = 1
Trks = 2
# No gradients are needed for this check, so don't build an autograd graph.
with torch.no_grad():
    jet_pred, trk_pred = model(X_train[Event_no][Jets], X_train[Event_no][Trk_Jet], X_train[Event_no][Trks])
    # Target index 0 is compared with the jet predictions, index 1 with the
    # track predictions.
    print("Test Case MSE Loss for jets:", jet_loss_fn(jet_pred, y_train[Event_no][0]).item())
    print("Test Case BCE Loss for trks:", trk_loss_fn(trk_pred, y_train[Event_no][1]).item())
print()
# Prefer the first CUDA device when one is visible; otherwise run on CPU.
cuda_ok = torch.cuda.is_available()
print("GPU Available: ", cuda_ok)
if cuda_ok:
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print()
###################
### Train Model ###
###################
# Initialize. PyTorch recommends moving a model to its device before
# constructing the optimizer over its parameters.
model.to(device)
optimizer = optim.AdamW(model.parameters(), lr=learning_rate)
# (jet, track) order, matching the model's (jet_pred, trk_pred) outputs.
loss_fns = [jet_loss_fn, trk_loss_fn]
train_val_data = [X_train, y_train, X_val, y_val]

# Train. perf_counter is monotonic, so the measured duration is not
# affected by wall-clock adjustments the way time.time() can be.
start_time = time.perf_counter()
combined_history = train(model, loss_fns, optimizer, device, train_val_data, out_path, lr_step_size, lr_step_gamma, Epochs)
end_time = time.perf_counter()
print()
print("Training Time: ", round(end_time - start_time, 1), "(s)")

# Saves the whole pickled module, so loading it later requires the PUMiNet
# class to be importable.
torch.save(model, os.path.join(out_path, 'model_final.torch'))
# Plot the loss history returned by train(): column 0 is drawn as the
# training curve and column 1 as the validation curve, on a log-scaled
# y axis, and the figure is saved into out_path.
fig, ax = plt.subplots()
ax.plot(combined_history[:, 0], label="Train")
ax.plot(combined_history[:, 1], label="Val")
ax.set_title('Loss')
ax.legend()
ax.set_yscale('log')
fig.savefig(out_path + "/Loss_Curve.png")
# plt.show() is intentionally not called -- presumably so the script can
# run without a display; open Loss_Curve.png instead.