This repository was archived by the owner on Jul 26, 2025. It is now read-only.
forked from tomgoldstein/loss-landscape
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathnet_plotter.py
More file actions
executable file
·158 lines (131 loc) · 5.89 KB
/
net_plotter.py
File metadata and controls
executable file
·158 lines (131 loc) · 5.89 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
"""
Manipulate network parameters and setup random directions with normalization.
"""
import torch
import copy
################################################################################
# Supporting functions for weights manipulation
################################################################################
def get_weights(net):
    """ Collect the tensor behind every parameter of net.

    Returns a list with one entry per item yielded by ``net.parameters()``,
    in the same order. Each entry is that parameter's ``.data`` attribute;
    this function does not clone it.
    """
    tensors = []
    for param in net.parameters():
        tensors.append(param.data)
    return tensors
def set_weights(net, weights, directions=None, step=None):
    """
    Overwrite the network's weights with a specified list of tensors
    or change weights along directions with a step size.

    Args:
        net: model whose ``parameters()`` are updated.
        weights: list of tensors, one per parameter, in the order of
            ``net.parameters()``.
        directions: ``None`` to copy ``weights`` into the parameters as-is.
            Otherwise a list holding either one direction or two directions,
            where each direction is a list with one entry per parameter.
        step: required when ``directions`` is given. A scalar when there is
            one direction; an indexable pair ``(step_x, step_y)`` when there
            are two.
    """
    if directions is None:
        # You cannot specify a step length without a direction.
        for p, w in zip(net.parameters(), weights):
            # copy_ performs any dtype/device conversion itself, so the
            # source is not cast beforehand.
            p.data.copy_(w)
        return

    assert step is not None, 'If a direction is specified then step must be specified as well'
    if len(directions) == 2:
        dx, dy = directions
        changes = [d0*step[0] + d1*step[1] for d0, d1 in zip(dx, dy)]
    else:
        changes = [d*step for d in directions[0]]

    for p, w, d in zip(net.parameters(), weights, changes):
        # Rebind p.data to a fresh tensor instead of writing in place:
        # `weights` may be the very tensors currently backing the parameters
        # and must come out of this call unchanged.
        p.data = w + torch.as_tensor(d, dtype=w.dtype, device=w.device)
def set_states(net, states, directions=None, step=None):
    """
    Overwrite the network's state_dict or change it along directions with a step size.

    Args:
        net: model that receives the result through ``load_state_dict``.
        states: mapping of state entries (as accepted by
            ``net.load_state_dict``). It is deep-copied before being
            changed, so the caller's mapping is left untouched.
        directions: ``None`` to load ``states`` unchanged. Otherwise a list
            holding either one direction or two directions, where each
            direction is a list with one entry per item of ``states``.
        step: required when ``directions`` is given. A scalar when there is
            one direction; an indexable pair ``(step_x, step_y)`` when there
            are two.
    """
    if directions is None:
        net.load_state_dict(states)
        return

    assert step is not None, 'If direction is provided then the step must be specified as well'
    if len(directions) == 2:
        dx, dy = directions
        changes = [d0*step[0] + d1*step[1] for d0, d1 in zip(dx, dy)]
    else:
        changes = [d*step for d in directions[0]]

    new_states = copy.deepcopy(states)
    assert len(new_states) == len(changes)
    for v, d in zip(new_states.values(), changes):
        # as_tensor takes tensors and array-likes alike without
        # copy-constructing, and casts the change to the entry's own
        # dtype/device before the in-place add.
        v.add_(torch.as_tensor(d, dtype=v.dtype, device=v.device))

    net.load_state_dict(new_states)
def get_random_weights(weights):
    """
    Build a random direction: a list holding, for each tensor in weights
    and in the same order, a tensor of the same size drawn with torch.randn.
    """
    direction = []
    for tensor in weights:
        direction.append(torch.randn(tensor.size()))
    return direction
def get_random_states(states):
    """
    Build a random direction for a state mapping: a list holding, for each
    value of states and in iteration order, a tensor of the same size drawn
    with torch.randn. Every entry of the mapping gets a direction entry,
    so anything stored there besides weights (e.g. BN's running_mean/var)
    is included.
    """
    sizes = (entry.size() for entry in states.values())
    return [torch.randn(size) for size in sizes]
def get_diff_weights(weights, weights2):
    """ Return the element-wise differences ``weights2[i] - weights[i]`` as a list,
    i.e. the direction pointing from 'weights' to 'weights2'. Pairs are
    formed with zip, so the result is as long as the shorter input."""
    diffs = []
    for start, end in zip(weights, weights2):
        diffs.append(end - start)
    return diffs
def get_diff_states(states, states2):
    """ Return, as a list, the differences between the values of 'states2' and
    the values of 'states', paired by iteration order (keys are not compared),
    i.e. the direction pointing from 'states' to 'states2'."""
    diffs = []
    for start, end in zip(states.values(), states2.values()):
        diffs.append(end - start)
    return diffs
################################################################################
# Normalization Functions
################################################################################
def normalize_direction(direction, weights, norm='filter'):
    """
    Rescale the direction, in place, so that it has similar norm as the
    corresponding model weights at the chosen level.

    Args:
      direction: tensor of the random direction for one layer (modified in place)
      weights: tensor of the original model for one layer
      norm: normalization method,
            'filter' | 'layer' | 'weight' | 'dfilter' | 'dlayer'.
            Any other value leaves direction unchanged.
    """
    # Added to every norm used as a divisor, so an all-zero direction stays
    # zero instead of turning into NaN.
    eps = 1e-10

    if norm == 'filter':
        # Rescale the filters (weights in group) in 'direction' so that each
        # filter has the same norm as its corresponding filter in 'weights'.
        for d, w in zip(direction, weights):
            d.mul_(w.norm()/(d.norm() + eps))
    elif norm == 'layer':
        # Rescale the layer variables in the direction so that each layer has
        # the same norm as the layer variables in weights.
        direction.mul_(weights.norm()/(direction.norm() + eps))
    elif norm == 'weight':
        # Rescale the entries in the direction so that each entry has the same
        # scale as the corresponding weight.
        direction.mul_(weights)
    elif norm == 'dfilter':
        # Rescale the entries in the direction so that each filter direction
        # has the unit norm.
        for d in direction:
            d.div_(d.norm() + eps)
    elif norm == 'dlayer':
        # Rescale the entries in the direction so that each layer direction has
        # the unit norm.
        direction.div_(direction.norm() + eps)
def normalize_directions_for_weights(direction, weights, norm='filter', ignore='biasbn'):
    """
    Normalize, in place, each entry of direction against the matching entry
    of weights.

    Entries with more than one dimension are passed to normalize_direction
    with the given norm. Entries with at most one dimension are filled with
    zeros when ignore is 'biasbn', and otherwise overwritten with a copy of
    the matching weight. direction and weights must have equal length.
    """
    assert len(direction) == len(weights)
    zero_low_dim = (ignore == 'biasbn')
    for dir_entry, weight_entry in zip(direction, weights):
        if dir_entry.dim() > 1:
            normalize_direction(dir_entry, weight_entry, norm)
            continue
        if zero_low_dim:
            # ignore directions for weights with 1 dimension
            dir_entry.fill_(0)
        else:
            # keep directions for weights/bias that are only 1 per node
            dir_entry.copy_(weight_entry)
def normalize_directions_for_states(direction, states, norm='filter', ignore='ignore'):
    """
    Normalize, in place, each entry of direction against the matching value
    of the states mapping (paired by iteration order).

    Entries with more than one dimension are passed to normalize_direction
    with the given norm. Entries with at most one dimension are filled with
    zeros when ignore is 'biasbn', and otherwise (including the default)
    overwritten with a copy of the matching state value. direction and
    states must have equal length.
    """
    assert len(direction) == len(states)
    zero_low_dim = (ignore == 'biasbn')
    for dir_entry, (_name, state_entry) in zip(direction, states.items()):
        if dir_entry.dim() > 1:
            normalize_direction(dir_entry, state_entry, norm)
            continue
        if zero_low_dim:
            # ignore directions for weights with 1 dimension
            dir_entry.fill_(0)
        else:
            # keep directions for weights/bias that are only 1 per node
            dir_entry.copy_(state_entry)
def ignore_biasbn(directions):
    """ Fill with zeros, in place, every entry of directions that has at most
    one dimension (the bias and bn parameters); other entries are left as is. """
    low_dim_entries = (entry for entry in directions if entry.dim() <= 1)
    for entry in low_dim_entries:
        entry.fill_(0)