Skip to content

Commit 2131b97

Browse files
authored
Rf/visualize training (#106)
* Rf visualize training. * Adjust single hover styling. * Fix legend. * Add beartype annotation to visualize_training.
1 parent 39b6909 commit 2131b97

5 files changed

Lines changed: 66 additions & 16 deletions

File tree

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ Please cite with:
201201
title = {sequifier - causal transformer models for multivariate sequence modelling},
202202
year = {2025},
203203
publisher = {GitHub},
204-
version = {v1.1.0.4},
204+
version = {v1.1.0.5},
205205
url = {https://github.com/0xideas/sequifier}
206206
}
207207
```

docs/source/conf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
project = 'sequifier'
1616
copyright = '2025, Leon Luithlen'
1717
author = 'Leon Luithlen'
18-
release = 'v1.1.0.4'
18+
release = 'v1.1.0.5'
1919
html_baseurl = 'https://www.sequifier.com/'
2020

2121
# -- General configuration ---------------------------------------------------

documentation/consolidated-docs.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ Please cite with:
201201
title = {sequifier - causal transformer models for multivariate sequence modelling},
202202
year = {2025},
203203
publisher = {GitHub},
204-
version = {v1.1.0.4},
204+
version = {v1.1.0.5},
205205
url = {https://github.com/0xideas/sequifier}
206206
}
207207
```

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
44

55
[project]
66
name = "sequifier"
7-
version = "v1.1.0.4"
7+
version = "v1.1.0.5"
88
authors = [
99
{ name = "Leon Luithlen", email = "leontimnaluithlen@gmail.com" },
1010
]

src/sequifier/visualize_training.py

Lines changed: 62 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66
from typing import Any, Optional
77

88
import numpy as np
9+
import plotly.colors as pc # Added to fetch consistent colors
910
import plotly.graph_objects as go
11+
from beartype import beartype
1012

1113
# Import Loguru and your custom logger config
1214
from loguru import logger
@@ -72,6 +74,7 @@ def __init__(self, model_name: str):
7274
self.expected_num_batches: Optional[int] = None
7375
self.pending_var_loss_epoch: Optional[int] = None
7476

77+
@beartype
7578
def parse_file(self, log_file: str) -> TrainingMetrics:
7679
with open(log_file, "r") as f:
7780
for line_num, line in enumerate(f, 1):
@@ -83,6 +86,7 @@ def parse_file(self, log_file: str) -> TrainingMetrics:
8386
self._validate_final_metrics()
8487
return self.metrics
8588

89+
@beartype
8690
def _process_line(self, line: str) -> None:
8791
"""Routes the line to the appropriate sub-parser based on strict string matching."""
8892
if "[INFO] Validation | Epoch:" in line:
@@ -94,6 +98,7 @@ def _process_line(self, line: str) -> None:
9498
elif "[INFO] Epoch" in line or "[INFO] Validation" in line:
9599
self.pending_var_loss_epoch = None
96100

101+
@beartype
97102
def _process_validation(self, line: str) -> None:
98103
match = VAL_PATTERN.search(line)
99104
if not match:
@@ -115,6 +120,7 @@ def _process_validation(self, line: str) -> None:
115120
self.metrics.baseline_losses[epoch] = baseline
116121
self.pending_var_loss_epoch = epoch
117122

123+
@beartype
118124
def _process_var_loss(self, line: str) -> None:
119125
match = VAR_PATTERN.search(line)
120126
if not match:
@@ -138,6 +144,7 @@ def _process_var_loss(self, line: str) -> None:
138144

139145
self.pending_var_loss_epoch = None
140146

147+
@beartype
141148
def _process_training(self, line: str) -> None:
142149
match = TRAIN_PATTERN.search(line)
143150
if not match:
@@ -165,6 +172,7 @@ def _process_training(self, line: str) -> None:
165172

166173
self.metrics.train_losses[epoch][batch] = (num_batches, loss)
167174

175+
@beartype
168176
def _validate_chronology(self, epoch: int, batch: int, num_batches: int) -> None:
169177
if self.current_epoch is not None and self.current_batch is not None:
170178
if epoch == self.current_epoch and batch <= self.current_batch:
@@ -219,12 +227,14 @@ def _validate_final_metrics(self) -> None:
219227
# -------------------------------------------------------------------------
220228
# Utility Functions
221229
# -------------------------------------------------------------------------
230+
@beartype
222231
def parse_number(val: str) -> float:
223232
"""Strictly parse numbers, explicitly handling the 'NaN' strings."""
224233
val = val.strip()
225234
return np.nan if val == "NaN" else float(val)
226235

227236

237+
@beartype
228238
def parse_args_to_models(args: argparse.Namespace) -> list[str]:
229239
"""Extracts the list of models from a file or comma-separated string."""
230240
if os.path.isfile(args.models) and args.models.endswith(".txt"):
@@ -235,6 +245,7 @@ def parse_args_to_models(args: argparse.Namespace) -> list[str]:
235245
return [m.strip() for m in args.models.split(",") if m.strip()]
236246

237247

248+
@beartype
238249
def get_log_filepath(args: argparse.Namespace, model: str) -> str:
239250
"""Finds the appropriate log file for a given model."""
240251
log_pattern = os.path.join(
@@ -256,6 +267,7 @@ def get_log_filepath(args: argparse.Namespace, model: str) -> str:
256267
return log_files[0]
257268

258269

270+
@beartype
259271
def format_plot_data(
260272
metrics: TrainingMetrics, bucket_batches: Optional[int], model: str
261273
) -> dict[str, Any]:
@@ -319,6 +331,7 @@ def format_plot_data(
319331
# -------------------------------------------------------------------------
320332
# Plotting & Reporting
321333
# -------------------------------------------------------------------------
334+
@beartype
322335
def _generate_single_model_plot(
323336
model: str, data: dict[str, Any], yaxis_type: str, out_path: str
324337
) -> None:
@@ -334,14 +347,22 @@ def _generate_single_model_plot(
334347

335348
fig.add_trace(
336349
go.Scatter(
337-
x=data["val_x"], y=data["val_y"], mode="lines", name="Validation Loss"
350+
x=data["val_x"],
351+
y=data["val_y"],
352+
mode="lines",
353+
name="Validation Loss",
354+
hovertemplate=f"<b>{model}</b><br>Val Loss: %{{y}}<br>Epoch: %{{x}}<extra></extra>",
338355
),
339356
row=1,
340357
col=1,
341358
)
342359
fig.add_trace(
343360
go.Scatter(
344-
x=data["train_x"], y=data["train_y"], mode="lines", name="Training Loss"
361+
x=data["train_x"],
362+
y=data["train_y"],
363+
mode="lines",
364+
name="Training Loss",
365+
hovertemplate=f"<b>{model}</b><br>Train Loss: %{{y}}<br>Epoch: %{{x}}<extra></extra>",
345366
),
346367
row=1,
347368
col=1,
@@ -355,6 +376,7 @@ def _generate_single_model_plot(
355376
mode="lines",
356377
name="Baseline Loss",
357378
line=dict(dash="dash"),
379+
hovertemplate=f"<b>{model}</b><br>Baseline Loss: %{{y}}<br>Epoch: %{{x}}<extra></extra>",
358380
),
359381
row=1,
360382
col=1,
@@ -376,7 +398,15 @@ def _generate_single_model_plot(
376398
for e in epochs
377399
]
378400
fig.add_trace(
379-
go.Scatter(x=epochs, y=y_norm, mode="lines", name=var), row=1, col=2
401+
go.Scatter(
402+
x=epochs,
403+
y=y_norm,
404+
mode="lines",
405+
name=var,
406+
hovertemplate=f"<b>{var}</b>: %{{y}}<br>Epoch: %{{x}}<extra></extra>",
407+
),
408+
row=1,
409+
col=2,
380410
)
381411

382412
fig.update_xaxes(title_text="Epoch", dtick=1, row=1, col=2)
@@ -393,6 +423,7 @@ def _generate_single_model_plot(
393423
logger.info(f"Visualization HTML generated and saved successfully to {out_path}")
394424

395425

426+
@beartype
396427
def _generate_multi_model_plot(
397428
models: list[str], all_data: dict[str, Any], yaxis_type: str, out_path: str
398429
) -> None:
@@ -401,22 +432,39 @@ def _generate_multi_model_plot(
401432
rows=1, cols=2, subplot_titles=("Validation Losses", "Training Losses")
402433
)
403434
baseline_val = None
435+
colors = pc.qualitative.Plotly # Load Plotly's default distinct color array
404436

405-
for model in models:
437+
for i, model in enumerate(models):
406438
data = all_data[model]
439+
color = colors[i % len(colors)] # Cycle colors if models exceed palette limit
440+
441+
# Validation trace
407442
fig.add_trace(
408443
go.Scatter(
409-
x=data["val_x"], y=data["val_y"], mode="lines", name=f"{model} Val Loss"
444+
x=data["val_x"],
445+
y=data["val_y"],
446+
mode="lines",
447+
name=model,
448+
legendgroup=model,
449+
line=dict(color=color),
450+
showlegend=True, # Only show validation trace in legend to prevent duplicates
451+
hovertemplate=f"<b>{model}</b><br>Val Loss: %{{y}}<br>Epoch: %{{x}}<extra></extra>",
410452
),
411453
row=1,
412454
col=1,
413455
)
456+
457+
# Training trace
414458
fig.add_trace(
415459
go.Scatter(
416460
x=data["train_x"],
417461
y=data["train_y"],
418462
mode="lines",
419-
name=f"{model} Train Loss",
463+
name=model,
464+
legendgroup=model,
465+
line=dict(color=color),
466+
showlegend=False, # Hidden from legend, but linked via legendgroup
467+
hovertemplate=f"<b>{model}</b><br>Train Loss: %{{y}}<br>Epoch: %{{x}}<extra></extra>",
420468
),
421469
row=1,
422470
col=2,
@@ -433,20 +481,20 @@ def _generate_multi_model_plot(
433481
)
434482

435483
if baseline_val is not None:
436-
max_x = max(
437-
[max(all_data[m]["train_x"]) for m in models if all_data[m]["train_x"]]
438-
+ [0]
484+
# Plot baseline on the Validation subplot (col=1)
485+
max_val_x = max(
486+
[max(all_data[m]["val_x"]) for m in models if all_data[m]["val_x"]] + [0]
439487
)
440488
fig.add_trace(
441489
go.Scatter(
442-
x=[0, max_x],
490+
x=[0, max_val_x],
443491
y=[baseline_val, baseline_val],
444492
mode="lines",
445493
name="Baseline Loss",
446-
line=dict(dash="dash"),
494+
line=dict(dash="dash", color="black"),
447495
),
448496
row=1,
449-
col=2,
497+
col=1,
450498
)
451499

452500
fig.update_xaxes(title_text="Epoch", dtick=1, row=1, col=1)
@@ -459,6 +507,7 @@ def _generate_multi_model_plot(
459507
logger.info(f"Visualization HTML generated and saved successfully to {out_path}")
460508

461509

510+
@beartype
462511
def generate_html_report(
463512
all_data: dict[str, Any], models: list[str], args: argparse.Namespace
464513
) -> None:
@@ -480,6 +529,7 @@ def generate_html_report(
480529
# -------------------------------------------------------------------------
481530
# Orchestrator
482531
# -------------------------------------------------------------------------
532+
@beartype
483533
def visualize_training(args: argparse.Namespace) -> None:
484534
"""Main orchestrator function."""
485535
models = parse_args_to_models(args)

0 commit comments

Comments (0)