Skip to content

Commit a7637d8

Browse files
committed
Added a script to compute the Empirical Fisher matrix
1 parent 6673ae7 commit a7637d8

File tree

3 files changed

+69
-2
lines changed

3 files changed

+69
-2
lines changed

neuralmonkey/decoders/transformer.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,8 @@ def __init__(self,
214214
self._variable_scope.set_initializer(tf.variance_scaling_initializer(
215215
mode="fan_avg", distribution="uniform"))
216216

217+
if reuse:
218+
self._variable_scope.reuse_variables()
217219
log("Decoder cost op: {}".format(self.cost))
218220
self._variable_scope.reuse_variables()
219221
log("Runtime logits: {}".format(self.runtime_logits))

scripts/compute_fisher.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
#!/usr/bin/env python3
2+
"""Compute the Empirical Fisher matrix using a list of gradients.
3+
4+
The gradient tensors can be spread over multiple npz files. The mean
5+
is computed over the first dimension (supposed to be a batch).
6+
7+
"""
8+
9+
import argparse
10+
import os
11+
import re
12+
import glob
13+
14+
import numpy as np
15+
16+
from neuralmonkey.logging import log as _log
17+
18+
19+
def log(message: str, color: str = "blue") -> None:
    """Forward *message* to neuralmonkey's logger, defaulting to blue output."""
    _log(message, color=color)
21+
22+
23+
def main() -> None:
    """Accumulate squared gradients from npz files into an Empirical Fisher.

    For every file matching ``<file_prefix>.*npz``, the squared gradient
    tensors are summed over the first (batch) dimension; the running sums
    are finally divided by the total number of batch items seen across all
    files and saved to ``--output_path`` as a single npz archive.

    Raises:
        FileNotFoundError: if no npz file matches the given prefix.
        ValueError: if tensors within one file disagree on the batch size.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--file_prefix", type=str, required=True,
                        help="prefix of the npz files containing the gradients")
    parser.add_argument("--output_path", type=str, required=True,
                        help="Path to output the Empirical Fisher to.")
    args = parser.parse_args()

    output_dict = {}  # variable name -> running sum of squared gradients
    n = 0  # total number of batch items processed over all files

    paths = glob.glob("{}.*npz".format(args.file_prefix))
    if not paths:
        # Silently writing an empty Fisher would hide a bad --file_prefix.
        raise FileNotFoundError(
            "No gradient files match {}.*npz".format(args.file_prefix))

    for path in paths:
        log("Processing {}".format(path))
        tensors = np.load(path)

        # The first dimension (batch) must be equal for all tensors in a file.
        shapes = [tensors[name].shape for name in tensors.files]
        if not shapes:
            continue  # nothing to accumulate from an empty archive
        if any(shape[0] != shapes[0][0] for shape in shapes):
            raise ValueError(
                "Inconsistent batch dimension among tensors in {}".format(path))

        for varname in tensors.files:
            # Per-variable sum of squared per-example gradients.
            squared_sum = np.sum(np.square(tensors[varname]), 0)
            if varname in output_dict:
                output_dict[varname] += squared_sum
            else:
                output_dict[varname] = squared_sum

        n += shapes[0][0]

    # Normalize the accumulated squared gradients by the sample count.
    for name in output_dict:
        output_dict[name] /= n

    np.savez(args.output_path, **output_dict)


if __name__ == "__main__":
    main()

tests/vocab.ini

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ tf_manager=<tf_manager>
44
output="tests/outputs/vocab"
55
overwrite_output_dir=True
66
batch_size=16
7-
epochs=0
7+
epochs=1
88
train_dataset=<train_data>
99
val_dataset=<val_data>
1010
trainer=<trainer>
@@ -66,11 +66,20 @@ dropout_keep_prob=0.5
6666
data_id="target"
6767
vocabulary=<decoder_vocabulary>
6868

69-
[trainer]
69+
[trainer1]
7070
class=trainers.cross_entropy_trainer.CrossEntropyTrainer
7171
decoders=[<decoder>]
7272
regularizers=[<train_l2>]
7373

74+
[trainer2]
75+
class=trainers.cross_entropy_trainer.CrossEntropyTrainer
76+
decoders=[<decoder>]
77+
regularizers=[<train_l2>]
78+
79+
[trainer]
80+
class=trainers.multitask_trainer.MultitaskTrainer
81+
trainers=[<trainer1>, <trainer2>]
82+
7483
[train_l2]
7584
class=trainers.regularizers.L2Regularizer
7685
weight=1.0e-8

0 commit comments

Comments
 (0)