66from typing import Any , Optional
77
88import numpy as np
9+ import plotly .colors as pc # Added to fetch consistent colors
910import plotly .graph_objects as go
11+ from beartype import beartype
1012
1113# Import Loguru and your custom logger config
1214from loguru import logger
@@ -72,6 +74,7 @@ def __init__(self, model_name: str):
7274 self .expected_num_batches : Optional [int ] = None
7375 self .pending_var_loss_epoch : Optional [int ] = None
7476
77+ @beartype
7578 def parse_file (self , log_file : str ) -> TrainingMetrics :
7679 with open (log_file , "r" ) as f :
7780 for line_num , line in enumerate (f , 1 ):
@@ -83,6 +86,7 @@ def parse_file(self, log_file: str) -> TrainingMetrics:
8386 self ._validate_final_metrics ()
8487 return self .metrics
8588
89+ @beartype
8690 def _process_line (self , line : str ) -> None :
8791 """Routes the line to the appropriate sub-parser based on strict string matching."""
8892 if "[INFO] Validation | Epoch:" in line :
@@ -94,6 +98,7 @@ def _process_line(self, line: str) -> None:
9498 elif "[INFO] Epoch" in line or "[INFO] Validation" in line :
9599 self .pending_var_loss_epoch = None
96100
101+ @beartype
97102 def _process_validation (self , line : str ) -> None :
98103 match = VAL_PATTERN .search (line )
99104 if not match :
@@ -115,6 +120,7 @@ def _process_validation(self, line: str) -> None:
115120 self .metrics .baseline_losses [epoch ] = baseline
116121 self .pending_var_loss_epoch = epoch
117122
123+ @beartype
118124 def _process_var_loss (self , line : str ) -> None :
119125 match = VAR_PATTERN .search (line )
120126 if not match :
@@ -138,6 +144,7 @@ def _process_var_loss(self, line: str) -> None:
138144
139145 self .pending_var_loss_epoch = None
140146
147+ @beartype
141148 def _process_training (self , line : str ) -> None :
142149 match = TRAIN_PATTERN .search (line )
143150 if not match :
@@ -165,6 +172,7 @@ def _process_training(self, line: str) -> None:
165172
166173 self .metrics .train_losses [epoch ][batch ] = (num_batches , loss )
167174
175+ @beartype
168176 def _validate_chronology (self , epoch : int , batch : int , num_batches : int ) -> None :
169177 if self .current_epoch is not None and self .current_batch is not None :
170178 if epoch == self .current_epoch and batch <= self .current_batch :
@@ -219,12 +227,14 @@ def _validate_final_metrics(self) -> None:
219227# -------------------------------------------------------------------------
220228# Utility Functions
221229# -------------------------------------------------------------------------
230+ @beartype
def parse_number(val: str) -> float:
    """Convert a numeric string to float, mapping the literal "NaN" to np.nan.

    The input is stripped of surrounding whitespace first; any other
    non-numeric string raises ValueError via float(), keeping parsing strict.
    """
    cleaned = val.strip()
    if cleaned == "NaN":
        return np.nan
    return float(cleaned)
226235
227236
237+ @beartype
228238def parse_args_to_models (args : argparse .Namespace ) -> list [str ]:
229239 """Extracts the list of models from a file or comma-separated string."""
230240 if os .path .isfile (args .models ) and args .models .endswith (".txt" ):
@@ -235,6 +245,7 @@ def parse_args_to_models(args: argparse.Namespace) -> list[str]:
235245 return [m .strip () for m in args .models .split ("," ) if m .strip ()]
236246
237247
248+ @beartype
238249def get_log_filepath (args : argparse .Namespace , model : str ) -> str :
239250 """Finds the appropriate log file for a given model."""
240251 log_pattern = os .path .join (
@@ -256,6 +267,7 @@ def get_log_filepath(args: argparse.Namespace, model: str) -> str:
256267 return log_files [0 ]
257268
258269
270+ @beartype
259271def format_plot_data (
260272 metrics : TrainingMetrics , bucket_batches : Optional [int ], model : str
261273) -> dict [str , Any ]:
@@ -319,6 +331,7 @@ def format_plot_data(
319331# -------------------------------------------------------------------------
320332# Plotting & Reporting
321333# -------------------------------------------------------------------------
334+ @beartype
322335def _generate_single_model_plot (
323336 model : str , data : dict [str , Any ], yaxis_type : str , out_path : str
324337) -> None :
@@ -334,14 +347,22 @@ def _generate_single_model_plot(
334347
335348 fig .add_trace (
336349 go .Scatter (
337- x = data ["val_x" ], y = data ["val_y" ], mode = "lines" , name = "Validation Loss"
350+ x = data ["val_x" ],
351+ y = data ["val_y" ],
352+ mode = "lines" ,
353+ name = "Validation Loss" ,
354+ hovertemplate = f"<b>{ model } </b><br>Val Loss: %{{y}}<br>Epoch: %{{x}}<extra></extra>" ,
338355 ),
339356 row = 1 ,
340357 col = 1 ,
341358 )
342359 fig .add_trace (
343360 go .Scatter (
344- x = data ["train_x" ], y = data ["train_y" ], mode = "lines" , name = "Training Loss"
361+ x = data ["train_x" ],
362+ y = data ["train_y" ],
363+ mode = "lines" ,
364+ name = "Training Loss" ,
365+ hovertemplate = f"<b>{ model } </b><br>Train Loss: %{{y}}<br>Epoch: %{{x}}<extra></extra>" ,
345366 ),
346367 row = 1 ,
347368 col = 1 ,
@@ -355,6 +376,7 @@ def _generate_single_model_plot(
355376 mode = "lines" ,
356377 name = "Baseline Loss" ,
357378 line = dict (dash = "dash" ),
379+ hovertemplate = f"<b>{ model } </b><br>Baseline Loss: %{{y}}<br>Epoch: %{{x}}<extra></extra>" ,
358380 ),
359381 row = 1 ,
360382 col = 1 ,
@@ -376,7 +398,15 @@ def _generate_single_model_plot(
376398 for e in epochs
377399 ]
378400 fig .add_trace (
379- go .Scatter (x = epochs , y = y_norm , mode = "lines" , name = var ), row = 1 , col = 2
401+ go .Scatter (
402+ x = epochs ,
403+ y = y_norm ,
404+ mode = "lines" ,
405+ name = var ,
406+ hovertemplate = f"<b>{ var } </b>: %{{y}}<br>Epoch: %{{x}}<extra></extra>" ,
407+ ),
408+ row = 1 ,
409+ col = 2 ,
380410 )
381411
382412 fig .update_xaxes (title_text = "Epoch" , dtick = 1 , row = 1 , col = 2 )
@@ -393,6 +423,7 @@ def _generate_single_model_plot(
393423 logger .info (f"Visualization HTML generated and saved successfully to { out_path } " )
394424
395425
426+ @beartype
396427def _generate_multi_model_plot (
397428 models : list [str ], all_data : dict [str , Any ], yaxis_type : str , out_path : str
398429) -> None :
@@ -401,22 +432,39 @@ def _generate_multi_model_plot(
401432 rows = 1 , cols = 2 , subplot_titles = ("Validation Losses" , "Training Losses" )
402433 )
403434 baseline_val = None
435+ colors = pc .qualitative .Plotly # Load Plotly's default distinct color array
404436
405- for model in models :
437+ for i , model in enumerate ( models ) :
406438 data = all_data [model ]
439+ color = colors [i % len (colors )] # Cycle colors if models exceed palette limit
440+
441+ # Validation trace
407442 fig .add_trace (
408443 go .Scatter (
409- x = data ["val_x" ], y = data ["val_y" ], mode = "lines" , name = f"{ model } Val Loss"
444+ x = data ["val_x" ],
445+ y = data ["val_y" ],
446+ mode = "lines" ,
447+ name = model ,
448+ legendgroup = model ,
449+ line = dict (color = color ),
450+ showlegend = True , # Only show validation trace in legend to prevent duplicates
451+ hovertemplate = f"<b>{ model } </b><br>Val Loss: %{{y}}<br>Epoch: %{{x}}<extra></extra>" ,
410452 ),
411453 row = 1 ,
412454 col = 1 ,
413455 )
456+
457+ # Training trace
414458 fig .add_trace (
415459 go .Scatter (
416460 x = data ["train_x" ],
417461 y = data ["train_y" ],
418462 mode = "lines" ,
419- name = f"{ model } Train Loss" ,
463+ name = model ,
464+ legendgroup = model ,
465+ line = dict (color = color ),
466+ showlegend = False , # Hidden from legend, but linked via legendgroup
467+ hovertemplate = f"<b>{ model } </b><br>Train Loss: %{{y}}<br>Epoch: %{{x}}<extra></extra>" ,
420468 ),
421469 row = 1 ,
422470 col = 2 ,
@@ -433,20 +481,20 @@ def _generate_multi_model_plot(
433481 )
434482
435483 if baseline_val is not None :
436- max_x = max (
437- [ max (all_data [ m ][ "train_x" ]) for m in models if all_data [ m ][ "train_x" ]]
438- + [0 ]
484+ # Plot baseline on the Validation subplot (col=1)
485+ max_val_x = max (
486+ [ max ( all_data [ m ][ "val_x" ]) for m in models if all_data [ m ][ "val_x" ]] + [0 ]
439487 )
440488 fig .add_trace (
441489 go .Scatter (
442- x = [0 , max_x ],
490+ x = [0 , max_val_x ],
443491 y = [baseline_val , baseline_val ],
444492 mode = "lines" ,
445493 name = "Baseline Loss" ,
446- line = dict (dash = "dash" ),
494+ line = dict (dash = "dash" , color = "black" ),
447495 ),
448496 row = 1 ,
449- col = 2 ,
497+ col = 1 ,
450498 )
451499
452500 fig .update_xaxes (title_text = "Epoch" , dtick = 1 , row = 1 , col = 1 )
@@ -459,6 +507,7 @@ def _generate_multi_model_plot(
459507 logger .info (f"Visualization HTML generated and saved successfully to { out_path } " )
460508
461509
510+ @beartype
462511def generate_html_report (
463512 all_data : dict [str , Any ], models : list [str ], args : argparse .Namespace
464513) -> None :
@@ -480,6 +529,7 @@ def generate_html_report(
480529# -------------------------------------------------------------------------
481530# Orchestrator
482531# -------------------------------------------------------------------------
532+ @beartype
483533def visualize_training (args : argparse .Namespace ) -> None :
484534 """Main orchestrator function."""
485535 models = parse_args_to_models (args )
0 commit comments