22# Datetime: 2021/2/23
33# Copyright belongs to the author.
44# Please indicate the source for reprinting.
5- import copy
65
76import paddle
87
@@ -40,6 +39,7 @@ def __init__(self,
4039 weight_initializer = None ):
4140 """
4241 :param layers: 需要重置的Layer 或 Layer列表
42+ :type layers: (paddle.nn.Layer|list)
4343 :param re_init_epoch: 经历多少EPOCH后重新初始化输出层
4444 :param max_re_num: Layer最大重置次数
4545 :param weight_initializer: 权重默认初始化方案(若为None则为原始权重,可为paddle.nn.initializer.XavierNormal())
@@ -58,41 +58,27 @@ def apply(self, current_epoch: int):
5858 应用RIFLE
5959 :param current_epoch: 当前遍历过的EPOCH数量
6060 """
61- if current_epoch % self .re_init_epoch == 0 and (current_epoch // self .re_init_epoch ) <= self .max_re_num :
62- print_str = f"Initialization successful, { len (self .layers )} group layers will apply RIFLE"
63-
64- if current_epoch == 0 and self .weight_initializer is None :
65- for layer_id , layer in enumerate (self .layers ):
66- for param_id , param in enumerate (layer .parameters ()):
67- if ".w_" in param .name :
68- self .CACHE_PARAMS [f"w_{ layer_id } _{ param_id } " ] = layer .weight .numpy ()
69- elif ".b_" in param .name :
70- self .CACHE_PARAMS [f"b_{ layer_id } _{ param_id } " ] = layer .bias .numpy ()
71- else :
72- for layer_id , layer in enumerate (self .layers ):
61+ if current_epoch == 0 :
62+ print (f"\033 [0;37;41mInitialization successful, { len (self .layers )} group layers will apply RIFLE\033 [0m" )
63+ elif current_epoch % self .re_init_epoch == 0 and (current_epoch // self .re_init_epoch ) <= self .max_re_num :
64+ for layer_id , layer in enumerate (self .layers ):
65+ if self .weight_initializer is None :
66+ layer .parameters ().clear ()
67+ else :
7368 for param_id , param in enumerate (layer .parameters ()):
7469 if ".w_" in param .name :
75- if self .weight_initializer is not None :
76- layer .weight = layer .create_parameter (shape = layer .weight .shape ,
77- attr = None ,
78- dtype = layer .weight .dtype ,
79- is_bias = False ,
80- default_initializer = self .weight_initializer )
81- else :
82- layer .weight .set_value (self .CACHE_PARAMS [f"w_{ layer_id } _{ param_id } " ])
83-
70+ layer .weight = layer .create_parameter (shape = layer .weight .shape ,
71+ attr = None ,
72+ dtype = layer .weight .dtype ,
73+ is_bias = False ,
74+ default_initializer = self .weight_initializer )
8475 elif ".b_" in param .name :
85- if self .weight_initializer is not None :
86- layer .bias = layer .create_parameter (shape = layer .bias .shape ,
87- attr = None ,
88- dtype = layer .bias .dtype ,
89- is_bias = True )
90- else :
91- layer .bias .set_value (self .CACHE_PARAMS [f"b_{ layer_id } _{ param_id } " ])
92-
93- print_str = f"RIFLE: layer has been reset in the { current_epoch } epoch!"
76+ layer .bias = layer .create_parameter (shape = layer .bias .shape ,
77+ attr = None ,
78+ dtype = layer .bias .dtype ,
79+ is_bias = True )
9480
95- print (f"\033 [0;37;41m { print_str } \033 [0m" )
81+ print (f"\033 [0;37;41mRIFLE: layer has been reset in the { current_epoch } epoch! \033 [0m" )
9682
9783
9884class RIFLECallback (paddle .callbacks .Callback ):
@@ -126,6 +112,7 @@ def __init__(self,
126112 """
127113 RIFLE的CallBack实现
128114 :param layers: 需要进行RIFLE的输出层
115+ :type layers: (paddle.nn.Layer|list)
129116 :param re_init_epoch: 经历多少EPOCH后重新初始化输出层
130117 :param max_re_num: Layer最大重置次数
131118 :param weight_initializer: 权重默认初始化方案(若为None则为原始权重,可为paddle.nn.initializer.XavierNormal())
0 commit comments