Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
101 commits
Select commit Hold shift + click to select a range
08586ff
Added files
simonge Sep 16, 2024
df1b38e
Update
simonge Oct 30, 2024
810cc3b
Make canvas name more useful
simonge Mar 11, 2025
0973e5d
Add beam generator macro
simonge Mar 11, 2025
1323fe0
Update analysis
simonge Apr 30, 2025
6329372
Add phasespace scripts
simonge May 2, 2025
67d765c
Change name of functors
simonge May 2, 2025
3afabcb
Remove image and acceptance benchmarks for now
simonge Jun 4, 2025
99633c2
Prepare for CI testing
simonge Jun 4, 2025
0deb701
Remove other benchmarks
simonge Jun 4, 2025
163341b
Add missing config.yml
simonge Jun 4, 2025
7a58355
Correct typo
simonge Jun 4, 2025
5cd62be
Re-enable other benchmarks and update cores
simonge Jun 17, 2025
312ca42
Add some pass fail checks
simonge Jun 18, 2025
02e0c0e
Set sim and analysis dir by variable
simonge Jun 18, 2025
e39d0f9
Revert exclusion of other benchmarks
simonge Jun 19, 2025
67d4206
Review suggestions
simonge Jun 19, 2025
81fad99
Snakefile: fix for out-of-tree running
veprbl Jun 19, 2025
6b21ff1
Snakefile restore indentation
veprbl Jun 19, 2025
b18bf3f
Add header to inputs too
simonge Jun 23, 2025
c7e268f
Add low-q2 phase space electron tests
simonge May 20, 2025
c964d40
Add acceptance sim
simonge Jun 19, 2025
8b98ed6
Make simulation run with correct range and 1000x faster
simonge Jun 20, 2025
31fe8ef
Add outputs from script
simonge Jun 20, 2025
ccddf53
Change code inputs to workflow source path
simonge Jun 23, 2025
f44d690
rename phasespace to acceptance
simonge Jun 23, 2025
4df2119
Remove unused code
simonge Jun 23, 2025
96e8b50
Define both simulations in the yml
simonge Jun 23, 2025
40d4fe8
Merge remote-tracking branch 'origin/master' into beamline_acceptance
simonge Jun 23, 2025
6878af8
Add entry fraction plot
simonge Jun 24, 2025
09f8681
Make filtering more robust
simonge Jun 24, 2025
40e1fa9
Change entry limit warning
simonge Jun 24, 2025
158c8ef
Add reconstruction training based on beampipe exit
simonge Jun 25, 2025
d937fdb
Update model and atempt to setup snakemake
simonge Jun 25, 2025
e5a4ab8
Fix snakemane rule and silence npsim info
simonge Jun 25, 2025
13030cc
Fix snakemake attribute
simonge Jun 25, 2025
02665c3
Scale momentum to unit vector
simonge Jun 26, 2025
16adc7f
Add tensors to device too
simonge Jun 27, 2025
a208cd7
Update benchmarks/beamline/Snakefile
simonge Jul 7, 2025
d2ebe91
Various improvements
simonge Jul 14, 2025
c02b517
Merge branch 'beamline_training' of github.com:eic/detector_benchmark…
simonge Jul 14, 2025
48b1e11
Merge remote-tracking branch 'origin/master' into beamline_training
simonge Jul 14, 2025
ae27686
Lots of updates, filtering of lowq2 hepmc events
simonge Jul 25, 2025
3075944
Add some versitility
simonge Jul 28, 2025
ecd5829
Add resolution test
simonge Jul 28, 2025
91b5041
Change phi plots to degrees
simonge Jul 28, 2025
7723f89
Extra workflow steps and fixes
simonge Jul 29, 2025
322768a
Swapping to huberloss helps
simonge Jul 30, 2025
5613274
Merge remote-tracking branch 'origin/master' into beamline_training
simonge Jul 30, 2025
0238978
Move reconstruction to separate benchmark
simonge Aug 4, 2025
6a27a3e
include lowq2_reconstruction in snakefile
simonge Aug 4, 2025
fafb5e1
Rename processData to cleanData
simonge Aug 4, 2025
ec179b4
Rename ProcessData to LoadData
simonge Aug 4, 2025
5dc2f28
Add lowq2_reconstruction snakefile
simonge Aug 4, 2025
1fcba19
Update CI bits
simonge Aug 4, 2025
c0855b8
Temporarily comment out other benchmarks
simonge Aug 4, 2025
999fb63
Fix namings maybe
simonge Aug 4, 2025
2e50f83
Add failure conditions on resolutions
simonge Aug 4, 2025
e44eac7
Fix benchmark stage
simonge Aug 4, 2025
1fddc48
Fix naming
simonge Aug 4, 2025
7e38cdc
Allow int return to flag fail
simonge Aug 5, 2025
673af74
Separate resolution check from plot creation
simonge Aug 5, 2025
9dbd3aa
Fix naming and add onnx
simonge Aug 12, 2025
b6cde3b
Merge remote-tracking branch 'origin/master' into beamline_training
simonge Aug 12, 2025
7606250
Merge updates from training-CI branch
simonge Aug 19, 2025
06ba1dc
Fix beamline running
simonge Aug 19, 2025
f4f68e9
Require successful beamline benchmark before triggering reconstructio…
simonge Aug 19, 2025
0a89651
Remove optional projection from old inputs
simonge Aug 19, 2025
0f8f68d
Reinclude beamline benchmarks
simonge Aug 19, 2025
349cc53
Re-enable in snakefile too
simonge Aug 20, 2025
bce7cff
Merge branch 'master' into beamline_training
simonge Oct 2, 2025
8fbf0a0
Merge branch 'master' into beamline_training
simonge Oct 16, 2025
8d30eab
Update default
simonge Oct 21, 2025
425b16e
Merge branch 'beamline_training' of github.com:eic/detector_benchmark…
simonge Oct 21, 2025
9da9148
Fix spelling which found its way back in
simonge Oct 21, 2025
ede83e7
Try adding Caching
simonge Oct 21, 2025
c585013
Fix eicrecon options
veprbl Oct 21, 2025
9809b89
Temporary break reconstruction limits
simonge Oct 21, 2025
f82c8ad
Test python requirements
simonge Oct 21, 2025
e0093eb
remove argparse from requirements
simonge Oct 21, 2025
a6356ca
Add onnx dependancies
simonge Oct 22, 2025
04f08eb
Reduce events and epochs for testing
simonge Oct 22, 2025
5379d25
Use static opset version
simonge Oct 22, 2025
237eff3
add missing comma
simonge Oct 22, 2025
5488f48
Try some static versions
simonge Oct 22, 2025
8fa74ee
Update requirements to match eic shell
simonge Oct 22, 2025
b3bceea
Don't require gpu torch
simonge Oct 22, 2025
41af472
Try again
simonge Oct 22, 2025
8ab7b44
Test new workflow with added checks
simonge Oct 23, 2025
b719382
Try with other spelling
simonge Oct 23, 2025
a586f09
Try multiline formatting
simonge Oct 23, 2025
3f3e632
Remove explicit artifact declatation
simonge Oct 23, 2025
5e21f71
Remove other overwriting artifact declaration....
simonge Oct 23, 2025
7442174
Revert "Reduce events and epochs for testing"
simonge Oct 23, 2025
bfffb83
rename retraining jobs
simonge Oct 23, 2025
cf1825e
Revert "Temporary break reconstruction limits"
simonge Oct 23, 2025
1d2abf3
Split collect results
simonge Oct 23, 2025
e4cda34
Fix requirement name
simonge Oct 23, 2025
208dc20
reenable other benchmarks
simonge Oct 23, 2025
373588a
Update names of plots to include Low-Q2
simonge Oct 23, 2025
abca061
Add electron_beamline prefix to beamline outputs
simonge Oct 23, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,13 @@ __pycache__/
*$py.class
.ipynb_checkpoints

# results and simulation output
results/
sim_output/
.snakemake/
calibrations/
fieldmaps/

# test for calorimeter
calorimeters/test/
*.d
Expand Down
2 changes: 2 additions & 0 deletions .gitlab-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ include:
- local: 'benchmarks/calo_pid/config.yml'
- local: 'benchmarks/campaign/config.yml'
- local: 'benchmarks/ecal_gaps/config.yml'
- local: 'benchmarks/lowq2_reconstruction/config.yml'
- local: 'benchmarks/tracking_detectors/config.yml'
- local: 'benchmarks/tracking_performances/config.yml'
- local: 'benchmarks/tracking_performances_dis/config.yml'
Expand Down Expand Up @@ -165,6 +166,7 @@ deploy_results:
- "collect_results:campaign"
- "collect_results:ecal_gaps"
- "collect_results:lfhcal"
- "collect_results:lowq2_reconstruction"
- "collect_results:material_scan"
- "collect_results:pid"
- "collect_results:rich"
Expand Down
1 change: 1 addition & 0 deletions Snakefile
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ include: "benchmarks/beamline/Snakefile"
include: "benchmarks/calo_pid/Snakefile"
include: "benchmarks/campaign/Snakefile"
include: "benchmarks/ecal_gaps/Snakefile"
include: "benchmarks/lowq2_reconstruction/Snakefile"
include: "benchmarks/material_scan/Snakefile"
include: "benchmarks/tracking_performances/Snakefile"
include: "benchmarks/tracking_performances_dis/Snakefile"
Expand Down
105 changes: 60 additions & 45 deletions benchmarks/beamline/Snakefile
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
SIMOUTDIR="sim_output/beamline/"
ANALYSISDIR=SIMOUTDIR+"analysis/"

##########################################################################################
### Rules for checking the steering of the electron beam through the magnets
##########################################################################################

rule beamline_steering_sim:
input:
warmup="warmup/epic_ip6_extended.edm4hep.root",
macro=workflow.source_path("beamlineGPS.mac"),
output:
SIMOUTDIR+"beamlineTest{CAMPAIGN}.edm4hep.root",
cache: True
shell:
"""
exec npsim \
Expand All @@ -20,12 +25,41 @@ rule beamline_steering_sim:
--physics.rangecut 100*m
"""

rule beamline_steering_analysis:
input:
warmup="warmup/epic_ip6_extended.edm4hep.root",
script=workflow.source_path("beamlineAnalysis.C"),
header=workflow.source_path("shared_functions.h"),
data=SIMOUTDIR+"beamlineTest{CAMPAIGN}.edm4hep.root",
output:
rootfile=ANALYSISDIR+"electron_beamline_beamlineTestAnalysis{CAMPAIGN}.root",
beamspot_canvas=ANALYSISDIR+"electron_beamline_beamspot_{CAMPAIGN}.png",
x_px_canvas=ANALYSISDIR+"electron_beamline_x_px_{CAMPAIGN}.png",
y_py_canvas=ANALYSISDIR+"electron_beamline_y_py_{CAMPAIGN}.png",
fitted_position_means_stdevs_canvas=ANALYSISDIR+"electron_beamline_fitted_position_means_stdevs_{CAMPAIGN}.png",
fitted_momentum_means_stdevs_canvas=ANALYSISDIR+"electron_beamline_fitted_momentum_means_stdevs_{CAMPAIGN}.png",
pipe_parameter_canvas=ANALYSISDIR+"electron_beamline_pipe_parameter_{CAMPAIGN}.png",
params:
xml=os.getenv("DETECTOR_PATH")+"/epic_ip6_extended.xml",
shell:
"""
root -l -b -q '{input.script}("{input.data}", "{output.rootfile}", "{params.xml}",
"{output.beamspot_canvas}", "{output.x_px_canvas}", "{output.y_py_canvas}",
"{output.fitted_position_means_stdevs_canvas}", "{output.fitted_momentum_means_stdevs_canvas}",
"{output.pipe_parameter_canvas}")'
"""

##########################################################################################
### Rules for checking the acceptance of electrons at each stage of the beamline
##########################################################################################

rule beamline_acceptance_sim:
input:
warmup="warmup/epic_ip6_extended.edm4hep.root",
macro=workflow.source_path("acceptanceGPS.mac"),
output:
SIMOUTDIR+"acceptanceTest{CAMPAIGN}.edm4hep.root",
cache: True
shell:
"""
exec npsim \
Expand All @@ -34,47 +68,22 @@ rule beamline_acceptance_sim:
--enableG4GPS \
--macroFile {input.macro} \
--compactFile $DETECTOR_PATH/epic_ip6_extended.xml \
--printLevel WARNING \
--outputFile {output} \
--physics.rangecut 100*m
"""

rule beamline_steering_analysis:
input:
warmup="warmup/epic_ip6_extended.edm4hep.root",
script=workflow.source_path("beamlineAnalysis.C"),
header=workflow.source_path("shared_functions.h"),
data=SIMOUTDIR+"beamlineTest{CAMPAIGN}.edm4hep.root",
output:
rootfile=ANALYSISDIR+"beamlineTestAnalysis{CAMPAIGN}.root",
beamspot_canvas=ANALYSISDIR+"beamspot_{CAMPAIGN}.png",
x_px_canvas=ANALYSISDIR+"x_px_{CAMPAIGN}.png",
y_py_canvas=ANALYSISDIR+"y_py_{CAMPAIGN}.png",
fitted_position_means_stdevs_canvas=ANALYSISDIR+"fitted_position_means_stdevs_{CAMPAIGN}.png",
fitted_momentum_means_stdevs_canvas=ANALYSISDIR+"fitted_momentum_means_stdevs_{CAMPAIGN}.png",
pipe_parameter_canvas=ANALYSISDIR+"pipe_parameter_{CAMPAIGN}.png",
params:
xml=os.getenv("DETECTOR_PATH")+"/epic_ip6_extended.xml",
shell:
"""
root -l -b -q '{input.script}("{input.data}", "{output.rootfile}", "{params.xml}",
"{output.beamspot_canvas}", "{output.x_px_canvas}", "{output.y_py_canvas}",
"{output.fitted_position_means_stdevs_canvas}", "{output.fitted_momentum_means_stdevs_canvas}",
"{output.pipe_parameter_canvas}")'
"""

rule beamline_acceptance_analysis:
input:
warmup="warmup/epic_ip6_extended.edm4hep.root",
script=workflow.source_path("acceptanceAnalysis.C"),
header=workflow.source_path("shared_functions.h"),
data=SIMOUTDIR+"acceptanceTest{CAMPAIGN}.edm4hep.root",
output:
rootfile=ANALYSISDIR+"acceptanceTestAnalysis{CAMPAIGN}.root",
beampipe_canvas=ANALYSISDIR+"acceptance_in_beampipe_{CAMPAIGN}.png",
etheta_canvas=ANALYSISDIR+"acceptance_energy_theta_{CAMPAIGN}.png",
etheta_acceptance_canvas=ANALYSISDIR+"acceptance_energy_theta_acceptance_{CAMPAIGN}.png",
entries_canvas=ANALYSISDIR+"acceptance_entries_{CAMPAIGN}.png",
rootfile=ANALYSISDIR+"electron_beamline_acceptanceTestAnalysis{CAMPAIGN}.root",
beampipe_canvas=ANALYSISDIR+"electron_beamline_acceptance_in_beampipe_{CAMPAIGN}.png",
etheta_canvas=ANALYSISDIR+"electron_beamline_acceptance_energy_theta_{CAMPAIGN}.png",
etheta_acceptance_canvas=ANALYSISDIR+"electron_beamline_acceptance_energy_theta_acceptance_{CAMPAIGN}.png",
entries_canvas=ANALYSISDIR+"electron_beamline_acceptance_entries_{CAMPAIGN}.png",
params:
xml=os.getenv("DETECTOR_PATH")+"/epic_ip6_extended.xml",
shell:
Expand All @@ -83,28 +92,34 @@ rule beamline_acceptance_analysis:
"{output.entries_canvas}")'
"""

##########################################################################################
# Combine results
##########################################################################################
rule beamline:
input:
ANALYSISDIR+"beamlineTestAnalysis{CAMPAIGN}.root",
ANALYSISDIR+"beamspot_{CAMPAIGN}.png",
ANALYSISDIR+"x_px_{CAMPAIGN}.png",
ANALYSISDIR+"y_py_{CAMPAIGN}.png",
ANALYSISDIR+"fitted_position_means_stdevs_{CAMPAIGN}.png",
ANALYSISDIR+"fitted_momentum_means_stdevs_{CAMPAIGN}.png",
ANALYSISDIR+"pipe_parameter_{CAMPAIGN}.png",
ANALYSISDIR+"acceptanceTestAnalysis{CAMPAIGN}.root",
ANALYSISDIR+"acceptance_in_beampipe_{CAMPAIGN}.png",
ANALYSISDIR+"acceptance_energy_theta_{CAMPAIGN}.png",
ANALYSISDIR+"acceptance_energy_theta_acceptance_{CAMPAIGN}.png",
ANALYSISDIR+"acceptance_entries_{CAMPAIGN}.png",
ANALYSISDIR+"electron_beamline_beamlineTestAnalysis{CAMPAIGN}.root",
ANALYSISDIR+"electron_beamline_beamspot_{CAMPAIGN}.png",
ANALYSISDIR+"electron_beamline_x_px_{CAMPAIGN}.png",
ANALYSISDIR+"electron_beamline_y_py_{CAMPAIGN}.png",
ANALYSISDIR+"electron_beamline_fitted_position_means_stdevs_{CAMPAIGN}.png",
ANALYSISDIR+"electron_beamline_fitted_momentum_means_stdevs_{CAMPAIGN}.png",
ANALYSISDIR+"electron_beamline_pipe_parameter_{CAMPAIGN}.png",
ANALYSISDIR+"electron_beamline_acceptanceTestAnalysis{CAMPAIGN}.root",
ANALYSISDIR+"electron_beamline_acceptance_in_beampipe_{CAMPAIGN}.png",
ANALYSISDIR+"electron_beamline_acceptance_energy_theta_{CAMPAIGN}.png",
ANALYSISDIR+"electron_beamline_acceptance_energy_theta_acceptance_{CAMPAIGN}.png",
ANALYSISDIR+"electron_beamline_acceptance_entries_{CAMPAIGN}.png"
output:
directory("results/beamline/{CAMPAIGN}/")
directory("results/beamline/steering_{CAMPAIGN}/")
shell:
"""
mkdir {output}
cp {input} {output}
cp -r {input} {output}
"""

##########################################################################################
# Defualt running
##########################################################################################
rule beamline_local:
input:
"results/beamline/local/"
"results/beamline/steering_local/"
2 changes: 2 additions & 0 deletions benchmarks/beamline/beamlineAnalysis.C
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,8 @@ int beamlineAnalysis( TString inFile = "/scratch/EIC/G4out/beamline/b
std::cout << "Warning: Only " << h->GetEntries()/nEntries << " of particles contributing to histogram " << name
<< " , which is below the accepted threshold of " << acceptableEntries/nEntries << std::endl;
pass = 1;
} else{
std::cout << "Histogram " << name << " has " << h->GetEntries() << " entries." << std::endl;
}

// Get the pipe radius for this histogram
Expand Down
15 changes: 15 additions & 0 deletions benchmarks/lowq2_reconstruction/LoadData.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import uproot
import awkward as ak

def create_arrays(dataFiles,featureName="_TaggerTrackerFeatureTensor_floatData",targetName="_TaggerTrackerTargetTensor_floatData", entries=None, treeName="events"):

# List of branches to load
branches = [featureName,targetName]

# Load data from concatenated list of files
data = uproot.concatenate([f"{file}:{treeName}" for file in dataFiles], branches, entry_stop=entries, library="ak")

input_data = data[featureName]
target_data = data[targetName]

return input_data, target_data
155 changes: 155 additions & 0 deletions benchmarks/lowq2_reconstruction/RegressionModel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

class ProjectToX0Plane(nn.Module):
def forward(self, x):
# x shape: (batch, 6) -> [x, y, z, px, py, pz]
x0, y0, z0, px, py, pz = x.unbind(dim=1)

# Normalize momentum components
momentum = torch.sqrt(px**2 + py**2 + pz**2)
px_norm = px / momentum
py_norm = py / momentum
pz_norm = pz / momentum

# Avoid division by zero for px
# eps = 1e-8
# px_safe = torch.where(px_norm.abs() < eps, eps * torch.sign(px_norm) + eps, px_norm)
t = -x0 / px_norm

y_proj = y0 + py_norm * t
z_proj = z0 + pz_norm * t

# Output: [y_proj, z_proj, px_norm, py_norm]
return torch.stack([y_proj, z_proj, px_norm, py_norm], dim=1)

def project_numpy(self, arr):
"""
Projects a numpy array of shape (N, 6) using the forward method,
returns a numpy array of shape (N, 4).
"""
device = next(self.parameters()).device if any(p.device.type != 'cpu' for p in self.parameters()) else 'cpu'
x = torch.from_numpy(arr).float().to(device)
with torch.no_grad():
projected = self.forward(x)
return projected.cpu().numpy()

class RegressionModel(nn.Module):
def __init__(self):
super(RegressionModel, self).__init__()
self.project_to_x0 = ProjectToX0Plane()
self.fc1 = nn.Linear(4, 512)
self.fc2 = nn.Linear(512, 64)
self.fc3 = nn.Linear(64, 3) # Output layer for

# Normalization parameters
self.input_mean = nn.Parameter(torch.zeros(4), requires_grad=False)
self.input_std = nn.Parameter(torch.ones(4), requires_grad=False)
self.output_mean = nn.Parameter(torch.zeros(3), requires_grad=False)
self.output_std = nn.Parameter(torch.ones(3), requires_grad=False)


def forward(self, x):
# Conditionally apply projection
x = self.project_to_x0(x)
# Normalize inputs
x = (x - self.input_mean) / self.input_std

# Pass through the fully connected layers
x = self._core_forward(x)

# Denormalize outputs
x = x * self.output_std + self.output_mean
return x

def _core_forward(self, x):
# Core fully connected layers
x = torch.relu(self.fc1(x))
x = torch.relu(self.fc2(x))
x = self.fc3(x)
return x

def adapt(self, input_data, output_data):
# Normalization
self.input_mean.data = input_data.mean(dim=0)
self.input_std.data = input_data.std(dim=0)
self.output_mean.data = output_data.mean(dim=0)
self.output_std.data = output_data.std(dim=0)

def preprocess_data(model, data_loader, adapt=True):
inputs = data_loader.dataset.tensors[0]
targets = data_loader.dataset.tensors[1]


projected_inputs = model.project_to_x0(inputs)

# Compute normalization parameters
if adapt:
model.adapt(projected_inputs, targets)

# Normalize inputs and targets
normalized_inputs = (projected_inputs - model.input_mean ) / model.input_std
normalized_targets = (targets - model.output_mean) / model.output_std

# Replace the dataset with preprocessed data
data_loader.dataset.tensors = (normalized_inputs, normalized_targets)

def makeModel():
# Create the model
model = RegressionModel()
# Define the optimizer
optimizer = optim.Adam(model.parameters(), lr=0.0004)
# Define the loss function
criterion = nn.HuberLoss(delta=0.2) # Huber loss for regression tasks

return model, optimizer, criterion

def trainModel(epochs, train_loader, val_loader, device):

model, optimizer, criterion = makeModel()

model.to(device)

# Preprocess training and validation data
preprocess_data(model, train_loader, adapt=True)

# Preprocess validation data without adapting
preprocess_data(model, val_loader, adapt=False)

# Move data to the GPU
train_loader.dataset.tensors = (train_loader.dataset.tensors[0].to(device), train_loader.dataset.tensors[1].to(device))
val_loader.dataset.tensors = (val_loader.dataset.tensors[0].to(device), val_loader.dataset.tensors[1].to(device))

# Verify that the model parameters are on the GPU
for name, param in model.named_parameters():
print(f"{name} is on {param.device}")

for epoch in range(epochs):
model.train()
running_loss = 0.0
for inputs, targets in train_loader:
optimizer.zero_grad()
outputs = model._core_forward(inputs)
loss = criterion(outputs, targets)
loss.backward()
optimizer.step()
running_loss += loss.item() * inputs.size(0)

epoch_loss = running_loss / len(train_loader.dataset)


# Validation step
model.eval()
val_loss = 0.0
with torch.no_grad():
for val_inputs, val_targets in val_loader:
val_outputs = model._core_forward(val_inputs)
val_loss += criterion(val_outputs, val_targets).item() * val_inputs.size(0)

val_loss /= len(val_loader.dataset)

print(f"Epoch [{epoch+1}/{epochs}], Loss: {epoch_loss}, Val Loss: {val_loss}")

return model
Loading