Skip to content

Commit 85b4cd5

Browse files
committed
1.0rc
1 parent 2b3ec1a commit 85b4cd5

File tree

4 files changed

+64
-55
lines changed

4 files changed

+64
-55
lines changed

dev.py

Lines changed: 30 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,6 @@
1616
# 导入RIFLE模块
1717
from paddle_rifle.rifle import RIFLECallback
1818

19-
transform = Compose([ToTensor()])
20-
21-
train_data = MNIST(transform=transform)
22-
test_data = MNIST(mode="test", transform=transform)
23-
2419

2520
class Net(nn.Layer):
2621
def __init__(self, num_classes=10):
@@ -43,24 +38,37 @@ def forward(self, inputs):
4338
return x
4439

4540

46-
net = Net(num_classes=10)
47-
fc_layer = net.fc2
41+
def main(use_init: bool = False):
42+
transform = Compose([ToTensor()])
43+
44+
train_data = MNIST(transform=transform)
45+
test_data = MNIST(mode="test", transform=transform)
46+
47+
net = Net(num_classes=10)
48+
fc_layer = net.fc2
49+
50+
model = paddle.Model(network=net,
51+
inputs=paddle.static.InputSpec([1, 28, 28], name="ipt"),
52+
labels=paddle.static.InputSpec([1], dtype="int64", name="lab"))
53+
54+
rifle_cb = RIFLECallback(fc_layer,
55+
re_init_epoch=1,
56+
max_re_num=3,
57+
weight_initializer=paddle.nn.initializer.XavierNormal() if use_init else None)
4858

49-
model = paddle.Model(network=net,
50-
inputs=paddle.static.InputSpec([1, 28, 28], name="ipt"),
51-
labels=paddle.static.InputSpec([1], dtype="int64", name="lab"))
59+
sgd = paddle.optimizer.SGD(parameters=model.parameters())
60+
loss = paddle.nn.loss.CrossEntropyLoss()
61+
acc = paddle.metric.Accuracy((1, 5))
62+
model.prepare(sgd, loss, acc)
5263

53-
rifle_cb = RIFLECallback(fc_layer, 1, 3, weight_initializer=paddle.nn.initializer.XavierNormal())
64+
# 开始训练并传入RIFLE Callback
65+
model.fit(train_data,
66+
test_data,
67+
batch_size=256,
68+
epochs=2,
69+
log_freq=10,
70+
callbacks=[rifle_cb])
5471

55-
sgd = paddle.optimizer.SGD(parameters=model.parameters())
56-
loss = paddle.nn.loss.CrossEntropyLoss()
57-
acc = paddle.metric.Accuracy((1, 5))
58-
model.prepare(sgd, loss, acc)
5972

60-
# 开始训练并传入RIFLE Callback
61-
model.fit(train_data,
62-
test_data,
63-
batch_size=32,
64-
epochs=20,
65-
log_freq=100,
66-
callbacks=[rifle_cb])
73+
if __name__ == '__main__':
74+
main()

paddle_rifle/rifle.py

Lines changed: 19 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
# Datetime: 2021/2/23
33
# Copyright belongs to the author.
44
# Please indicate the source for reprinting.
5-
import copy
65

76
import paddle
87

@@ -40,6 +39,7 @@ def __init__(self,
4039
weight_initializer=None):
4140
"""
4241
:param layers: 需要重置的Layer 或 Layer列表
42+
:type layers: (paddle.nn.Layer|list)
4343
:param re_init_epoch: 经历多少EPOCH后重新初始化输出层
4444
:param max_re_num: Layer最大重置次数
4545
:param weight_initializer: 权重默认初始化方案(若为None则为原始权重,可为paddle.nn.initializer.XavierNormal())
@@ -58,41 +58,27 @@ def apply(self, current_epoch: int):
5858
应用RIFLE
5959
:param current_epoch: 当前遍历过的EPOCH数量
6060
"""
61-
if current_epoch % self.re_init_epoch == 0 and (current_epoch // self.re_init_epoch) <= self.max_re_num:
62-
print_str = f"Initialization successful, {len(self.layers)} group layers will apply RIFLE"
63-
64-
if current_epoch == 0 and self.weight_initializer is None:
65-
for layer_id, layer in enumerate(self.layers):
66-
for param_id, param in enumerate(layer.parameters()):
67-
if ".w_" in param.name:
68-
self.CACHE_PARAMS[f"w_{layer_id}_{param_id}"] = layer.weight.numpy()
69-
elif ".b_" in param.name:
70-
self.CACHE_PARAMS[f"b_{layer_id}_{param_id}"] = layer.bias.numpy()
71-
else:
72-
for layer_id, layer in enumerate(self.layers):
61+
if current_epoch == 0:
62+
print(f"\033[0;37;41mInitialization successful, {len(self.layers)} group layers will apply RIFLE\033[0m")
63+
elif current_epoch % self.re_init_epoch == 0 and (current_epoch // self.re_init_epoch) <= self.max_re_num:
64+
for layer_id, layer in enumerate(self.layers):
65+
if self.weight_initializer is None:
66+
layer.parameters().clear()
67+
else:
7368
for param_id, param in enumerate(layer.parameters()):
7469
if ".w_" in param.name:
75-
if self.weight_initializer is not None:
76-
layer.weight = layer.create_parameter(shape=layer.weight.shape,
77-
attr=None,
78-
dtype=layer.weight.dtype,
79-
is_bias=False,
80-
default_initializer=self.weight_initializer)
81-
else:
82-
layer.weight.set_value(self.CACHE_PARAMS[f"w_{layer_id}_{param_id}"])
83-
70+
layer.weight = layer.create_parameter(shape=layer.weight.shape,
71+
attr=None,
72+
dtype=layer.weight.dtype,
73+
is_bias=False,
74+
default_initializer=self.weight_initializer)
8475
elif ".b_" in param.name:
85-
if self.weight_initializer is not None:
86-
layer.bias = layer.create_parameter(shape=layer.bias.shape,
87-
attr=None,
88-
dtype=layer.bias.dtype,
89-
is_bias=True)
90-
else:
91-
layer.bias.set_value(self.CACHE_PARAMS[f"b_{layer_id}_{param_id}"])
92-
93-
print_str = f"RIFLE: layer has been reset in the {current_epoch} epoch!"
76+
layer.bias = layer.create_parameter(shape=layer.bias.shape,
77+
attr=None,
78+
dtype=layer.bias.dtype,
79+
is_bias=True)
9480

95-
print(f"\033[0;37;41m{print_str}\033[0m")
81+
print(f"\033[0;37;41mRIFLE: layer has been reset in the {current_epoch} epoch!\033[0m")
9682

9783

9884
class RIFLECallback(paddle.callbacks.Callback):
@@ -126,6 +112,7 @@ def __init__(self,
126112
"""
127113
RIFLE的CallBack实现
128114
:param layers: 需要进行RIFLE的输出层
115+
:type layers: (paddle.nn.Layer|list)
129116
:param re_init_epoch: 经历多少EPOCH后重新初始化输出层
130117
:param max_re_num: Layer最大重置次数
131118
:param weight_initializer: 权重默认初始化方案(若为None则为原始权重,可为paddle.nn.initializer.XavierNormal())

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
setup(
44
name='paddle-rifle',
5-
version='0.21',
5+
version='1.0rc',
66
packages=["paddle_rifle"],
77
url='https://github.com/GT-ZhangAcer/RIFLE_Module',
88
license='MIT',

unit_test.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
import unittest
2+
from dev import main
3+
4+
5+
class MyTestCase(unittest.TestCase):
6+
def test_none_init(self):
7+
main()
8+
9+
def test_init(self):
10+
main(True)
11+
12+
13+
if __name__ == '__main__':
14+
unittest.main()

0 commit comments

Comments
 (0)