-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathutils.py
More file actions
44 lines (31 loc) · 1.14 KB
/
utils.py
File metadata and controls
44 lines (31 loc) · 1.14 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
"""Shared pydantic settings configuration."""
import json
from pathlib import Path
from typing import Union
import matplotlib.pyplot as plt
from pydantic import BaseSettings as PydanticBaseSettings
class BaseSettings(PydanticBaseSettings):
    """Project-wide base for settings models, layered on Pydantic's BaseSettings.

    Subclass this instead of Pydantic's class so every settings model shares
    the ``Config`` declared below.
    """

    # NOTE(review): `from pydantic import BaseSettings` plus an inner
    # `class Config` is the pydantic v1 API; v2 moved BaseSettings to the
    # separate `pydantic-settings` package — confirm the pinned version.

    class Config:
        """Shared Pydantic configuration for all settings subclasses."""

        # Prefix prepended to field names when reading environment variables.
        env_prefix = "jv_"
        # Reject any field that is not declared on the model.
        extra = "forbid"
        # Store enum-typed fields as their raw values rather than Enum members.
        use_enum_values = True
def plot_learning_curve(
    results_dir: Union[str, Path], key: str = "mae", plot_train: bool = False
):
    """Plot learning curves from the JSON history files in a results directory.

    Reads ``history_val.json`` (and, when ``plot_train`` is set,
    ``history_train.json``) from ``results_dir`` and plots the series stored
    under ``key`` onto the current matplotlib axes. The validation curve is
    labelled with the directory name; the training curve reuses the same
    colour at reduced opacity. The x/y axis labels are set to "epochs" and
    ``key``.

    Args:
        results_dir: Directory containing the history JSON files.
        key: Metric name to plot; looked up in each loaded history
            (assumed to be a JSON object mapping metric -> list of values).
        plot_train: Also load and plot the training history.

    Returns:
        A ``(train, val)`` tuple of the loaded histories. ``train`` is
        ``None`` when ``plot_train`` is False.

    Raises:
        FileNotFoundError: If a required history file does not exist.
    """
    # Path() accepts str, Path, or any os.PathLike, so no type check is needed.
    results_dir = Path(results_dir)

    with open(results_dir / "history_val.json", "r", encoding="utf-8") as f:
        val = json.load(f)
    lines = plt.plot(val[key], label=results_dir.name)

    # Bind `train` up front so the return below is valid when plot_train
    # is False (otherwise it raises UnboundLocalError).
    train = None
    if plot_train:
        with open(results_dir / "history_train.json", "r", encoding="utf-8") as f:
            train = json.load(f)
        # Plot the training trace in the same color, lower opacity.
        color = lines[0].get_color()
        plt.plot(train[key], alpha=0.5, c=color)

    plt.xlabel("epochs")
    plt.ylabel(key)
    return train, val