-
Notifications
You must be signed in to change notification settings - Fork 6
Expand file tree
/
Copy pathweight_comparison.py
More file actions
119 lines (96 loc) · 4.58 KB
/
weight_comparison.py
File metadata and controls
119 lines (96 loc) · 4.58 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import argparse
import os
import torch
from typing import Dict
from safetensors.torch import load_file
from rich.console import Console
from rich.table import Table
from transformers import AutoModelForMaskedLM, AutoConfig, AutoModel
from fastplms.e1.modeling_e1 import E1ForMaskedLM, E1Config, E1Model
def load_weights(path: str, cast_fp32: bool = True) -> Dict[str, torch.Tensor]:
    """Load a checkpoint from ``path`` and return its flat state dict.

    ``.safetensors`` files are read with safetensors' ``load_file``; ``.pth``/``.pt``
    files with ``torch.load(..., weights_only=True)`` on CPU. Any other extension
    tries ``load_file`` first and falls back to ``torch.load``. If the loaded
    object is a dict that nests the weights under a ``"state_dict"`` or
    ``"model"`` key (checked in that order), the nested dict is returned
    instead, regardless of which loader produced it.

    Args:
        path: Checkpoint file path.
        cast_fp32: If True, every tensor value is converted with ``.float()``;
            non-tensor values are passed through unchanged.

    Returns:
        Mapping from parameter name to value.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
    """
    # Explicit raise instead of assert so the check survives ``python -O``.
    if not os.path.exists(path):
        raise FileNotFoundError(f"File {path} not found.")

    if path.endswith(".safetensors"):
        sd = load_file(path)
    elif path.endswith((".pth", ".pt")):
        sd = torch.load(path, map_location="cpu", weights_only=True)
    else:
        # Unknown extension: best-effort safetensors read, then torch pickle.
        # If torch.load also fails, its exception propagates to the caller.
        try:
            sd = load_file(path)
        except Exception:
            sd = torch.load(path, map_location="cpu", weights_only=True)

    # Unwrap nested checkpoints for every loader path (previously only the
    # .pth/.pt branch did this). Only unwrap when the wrapped value is itself a
    # dict, so a tensor that merely happens to be named "model" is left alone.
    if isinstance(sd, dict):
        for wrapper_key in ("state_dict", "model"):
            if isinstance(sd.get(wrapper_key), dict):
                sd = sd[wrapper_key]
                break

    if cast_fp32:
        return {k: v.float() if isinstance(v, torch.Tensor) else v for k, v in sd.items()}
    return sd
def _export_state_dict(model: torch.nn.Module, path: str) -> None:
    """Save ``model``'s state dict to ``path`` (the model is put in eval mode first)."""
    torch.save(model.eval().state_dict(), path)


def _compare_entry(
    key: str,
    ref_sd: Dict[str, torch.Tensor],
    other_sd: Dict[str, torch.Tensor],
    strict: bool,
) -> tuple[str, bool]:
    """Return the rich-markup table cell for ``key`` and whether the two dicts agree on it.

    A key missing from both dicts counts as agreement. A key missing from exactly
    one dict, a shape mismatch, or any value difference counts as a mismatch.

    Raises:
        TypeError: If either value for ``key`` is not a tensor.
    """
    has_ref = key in ref_sd
    has_other = key in other_sd
    if not (has_ref and has_other):
        if has_ref or has_other:
            return "[red]✘[/red]", False
        return "[dim]✔[/dim]", True

    ref_w = ref_sd[key]
    other_w = other_sd[key]
    if not isinstance(ref_w, torch.Tensor):
        raise TypeError(f"Weight {key} in reference is not a tensor.")
    if not isinstance(other_w, torch.Tensor):
        raise TypeError(f"Weight {key} in comparison file is not a tensor.")

    if ref_w.shape != other_w.shape:
        return "[red]✘ (Shape)[/red]", False

    # torch.equal is an exact element-wise test. The previous ``mse == 0`` check
    # could pass when tiny differences underflowed to 0 when squared in fp32,
    # and reported a NaN-MSE mismatch for empty tensors.
    if torch.equal(ref_w, other_w):
        return "[green]✔[/green]", True

    mse = torch.mean((ref_w.float() - other_w.float()) ** 2).item()
    if strict:
        return f"[red]✘ (Strict, MSE: {mse:.2e})[/red]", False
    return f"[red]✘ (MSE: {mse:.2e})[/red]", False


def main() -> None:
    """Compare checkpoints tensor-by-tensor against a reference and print a rich table.

    With no ``--file1``, the official E1-150M model is exported to ``official.pth``
    and used as the reference. With no ``--files``, the Synthyra remote-code models
    are exported to ``load_from_pretrained_1.pth`` / ``load_from_pretrained_2.pth``
    and compared together with ``old.safetensors``, which must already exist.
    User-supplied paths are never overwritten.

    Raises:
        AssertionError: With ``--strict --assert_exact``, if any tensor is missing
            from one side, differs in shape, or differs in value.
    """
    parser = argparse.ArgumentParser(description="Compare model weights against a reference checkpoint.")
    parser.add_argument(
        "--file1", type=str, default=None,
        help="Reference checkpoint (default: export the official model to official.pth).",
    )
    parser.add_argument(
        "--files", type=str, nargs="+", default=None,
        help="Checkpoints to compare against the reference.",
    )
    parser.add_argument(
        "--strict", action="store_true",
        help="Compare tensors in their stored dtypes instead of casting to fp32.",
    )
    parser.add_argument(
        "--assert_exact", action="store_true",
        help="With --strict, fail if any tensor is missing, mis-shaped, or differs.",
    )
    args = parser.parse_args()

    # Only download/export the default checkpoints that are actually needed. Each
    # model goes out of scope right after export, so only one is alive at a time.
    if args.file1 is None:
        _export_state_dict(
            E1ForMaskedLM.from_pretrained("Profluent-Bio/E1-150m", dtype=torch.float32),
            "official.pth",
        )
        args.file1 = "official.pth"
    if args.files is None:
        _export_state_dict(
            AutoModel.from_pretrained(
                "Synthyra/Profluent-E1-150M", dtype=torch.float32, trust_remote_code=True
            ),
            "load_from_pretrained_1.pth",
        )
        _export_state_dict(
            AutoModelForMaskedLM.from_pretrained(
                "Synthyra/Profluent-E1-150M", dtype=torch.float32, trust_remote_code=True
            ),
            "load_from_pretrained_2.pth",
        )
        args.files = ["load_from_pretrained_1.pth", "load_from_pretrained_2.pth", "old.safetensors"]

    paths = [args.file1] + args.files
    # Non-strict mode casts every tensor to fp32, so dtype-only differences are ignored.
    sds = [load_weights(p, cast_fp32=not args.strict) for p in paths]
    ref_sd, other_sds = sds[0], sds[1:]
    all_keys = sorted(set().union(*(sd.keys() for sd in sds)))

    console = Console()
    table = Table(title=f"Weights Comparison (Reference: {os.path.basename(paths[0])})")
    table.add_column("Tensor Name", style="cyan", no_wrap=True)
    for p in paths[1:]:
        table.add_column(f"{os.path.basename(p)} == Ref", justify="center")

    # Every mismatch is recorded, including missing keys and shape mismatches,
    # which previously slipped past --assert_exact.
    mismatches = []
    for k in all_keys:
        row = [k]
        for other_sd in other_sds:
            cell, matched = _compare_entry(k, ref_sd, other_sd, args.strict)
            row.append(cell)
            if not matched:
                mismatches.append(k)
        table.add_row(*row)
    console.print(table)

    # Explicit raise instead of assert so the check survives ``python -O``.
    if args.strict and args.assert_exact and mismatches:
        raise AssertionError(
            f"Found {len(mismatches)} strict mismatches. "
            f"First mismatches: {mismatches[:10]}"
        )
if __name__ == "__main__":
    # Script entry point: importing this module does not trigger the comparison.
    main()