@@ -87,3 +87,163 @@ def fuse_conv_bn(model: torch.nn.Module, inplace=False) -> torch.nn.Module:
                node.replace_all_uses_with(node.args[0])
                new_graph.erase_node(node)
    return fx.GraphModule(fx_model, new_graph)

class Conv2dReLU(torch.nn.Module):
    def __init__(self,
                 weight,
                 bias,
                 stride,
                 padding,
                 dilation,
                 groups):
        super(Conv2dReLU, self).__init__()
        self.weight = weight
        self.weight_is_channels_last = False
        self.bias = bias
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        # Heuristic: treat 7x7 kernels (e.g. the ResNet stem conv) as "slow fusion"
        # and keep them in contiguous (NCHW) memory format instead of channels_last.
        self.slow_fusion = False
        if self.weight.size(2) == 7 and self.weight.size(3) == 7:
            self.slow_fusion = True

    def forward(self, inp):
        # NOTE: This will be faster once https://github.com/pytorch/pytorch/pull/62482 lands
        if not self.slow_fusion and inp.is_contiguous(memory_format=torch.contiguous_format):
            inp = inp.to(memory_format=torch.channels_last)
        if self.slow_fusion and inp.is_contiguous(memory_format=torch.channels_last):
            inp = inp.to(memory_format=torch.contiguous_format)
        if not self.slow_fusion and not self.weight_is_channels_last:
            # Convert the weight to channels_last once, on first use.
            self.weight.data = self.weight.to(memory_format=torch.channels_last)
            inp = inp.to(memory_format=torch.channels_last)
            self.weight_is_channels_last = True
        return torch.cudnn_convolution_relu(inp,
                                            self.weight,
                                            self.bias,
                                            self.stride,
                                            self.padding,
                                            self.dilation,
                                            self.groups)

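
# Illustrative example (assumption, for reference only): on a CUDA build where
# torch.cudnn_convolution_relu is available, Conv2dReLU should match an unfused
# conv followed by ReLU:
#
#   conv = torch.nn.Conv2d(64, 64, kernel_size=3, padding=1).cuda().eval()
#   fused = Conv2dReLU(conv.weight, conv.bias, conv.stride,
#                      conv.padding, conv.dilation, conv.groups)
#   x = torch.randn(8, 64, 56, 56, device="cuda")
#   torch.testing.assert_close(fused(x), torch.relu(conv(x)), rtol=1e-3, atol=1e-3)
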
class Conv2dAddReLU(torch.nn.Module):
    def __init__(self,
                 weight,
                 bias,
                 stride,
                 padding,
                 dilation,
                 groups):
        super(Conv2dAddReLU, self).__init__()
        self.weight = weight
        self.weight_is_channels_last = False
        self.bias = bias
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.slow_fusion = False
        if self.weight.size(2) == 7 and self.weight.size(3) == 7:
            self.slow_fusion = True

    def forward(self, inp, add_input):
        # TODO: Reactivate this once cudnn_convolution_add_relu is fixed.
        # weight = self.weight.to(memory_format=torch.contiguous_format)
        # if not self.slow_fusion and inp.is_contiguous(memory_format=torch.contiguous_format):
        #     inp = inp.to(memory_format=torch.channels_last)
        #     add_input = add_input.to(memory_format=torch.channels_last)
        # if self.slow_fusion and inp.is_contiguous(memory_format=torch.channels_last):
        #     inp = inp.to(memory_format=torch.contiguous_format)
        #     add_input = add_input.to(memory_format=torch.contiguous_format)
        # if not self.slow_fusion and not self.weight_is_channels_last:
        #     self.weight.data = self.weight.to(memory_format=torch.channels_last)
        #     inp = inp.to(memory_format=torch.channels_last)
        #     add_input = add_input.to(memory_format=torch.channels_last)
        #     self.weight_is_channels_last = True
        # return torch.cudnn_convolution_add_relu(inp,
        #                                         self.weight,
        #                                         add_input,
        #                                         1.0,
        #                                         self.bias,
        #                                         self.stride,
        #                                         self.padding,
        #                                         self.dilation,
        #                                         self.groups)
        out = torch.conv2d(inp,
                           self.weight,
                           self.bias,
                           self.stride,
                           self.padding,
                           self.dilation,
                           self.groups)
        out.add_(add_input)
        out.relu_()
        return out

def fuse_conv_relu(model: torch.nn.Module, inplace=False) -> torch.nn.Module:
    """
    Fuses Conv2d/ReLU and Conv2d + add + ReLU patterns for inference purposes.
    Will deepcopy your model by default, but can modify the model inplace
    as well.
    """
    patterns = [(torch.nn.Conv2d, torch.nn.ReLU)]
    if not inplace:
        model = copy.deepcopy(model)
    fx_model = fx.symbolic_trace(model)
    modules = dict(fx_model.named_modules())
    new_graph = copy.deepcopy(fx_model.graph)

    for pattern in patterns:
        for node in new_graph.nodes:
            if matches_module_pattern(pattern, node, modules):
                if len(node.args[0].users) > 1:  # Output of conv is used by other nodes
                    continue
                conv = modules[node.args[0].target]
                relu = modules[node.target]
                fused_conv = Conv2dReLU(conv.weight, conv.bias, conv.stride, conv.padding, conv.dilation, conv.groups)
                replace_node_module(node.args[0], modules, fused_conv)
                node.replace_all_uses_with(node.args[0])
                new_graph.erase_node(node)

    # Match conv -> add -> ReLU triples with a sliding window over the last three
    # call_function / call_module nodes.
    last_nodes = []
    count = 0
    for node in new_graph.nodes:
        if count == 31:  # NOTE: hard-coded iteration cap.
            break
        if (node.op == "call_function" or node.op == "call_module"):
            last_nodes.append(node)
        if len(last_nodes) == 4:
            last_nodes = last_nodes[1:]
        if len(last_nodes) < 3:
            continue
        is_match = True
        is_match = is_match and (last_nodes[0].op == "call_module")
        is_match = is_match and (last_nodes[1].op == "call_function")
        is_match = is_match and (last_nodes[2].op == "call_module")
        is_match = is_match and isinstance(modules[last_nodes[0].target], torch.nn.Conv2d)
        # FX names add nodes "add", "add_1", ...
        is_match = is_match and (str(last_nodes[1]).split("_")[0] == "add")
        is_match = is_match and isinstance(modules[last_nodes[2].target], torch.nn.ReLU)
        if (is_match):
            conv = modules[last_nodes[1].args[0].target]
            fused_conv = Conv2dAddReLU(conv.weight, conv.bias, conv.stride, conv.padding, conv.dilation, conv.groups)
            replace_node_module(last_nodes[2], modules, fused_conv)
            last_nodes[2].args = (last_nodes[0].args[0], last_nodes[1].args[1])
            new_graph.erase_node(last_nodes[1])
            new_graph.erase_node(last_nodes[0])
        count += 1
    return fx.GraphModule(fx_model, new_graph)

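
# The sliding-window rewrite above targets residual-style code such as (illustrative
# sketch, assuming a block written roughly as):
#
#   out = self.conv(x)
#   out = out + identity
#   out = self.relu(out)
#
# which FX traces as call_module(conv) -> call_function(add) -> call_module(relu).
# After the rewrite, the ReLU node holds a Conv2dAddReLU module and is called as
# fused(x, identity).
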
def fuse_conv_add_relu(model: torch.nn.Module, inplace=False) -> torch.nn.Module:
    """
    Placeholder pass: currently it only symbolically traces the model, lints the
    copied graph, and returns an equivalent GraphModule. The actual conv + add +
    ReLU rewrite is performed in fuse_conv_relu above. Will deepcopy your model
    by default, but can modify the model inplace as well.
    """
    if not inplace:
        model = copy.deepcopy(model)
    fx_model = fx.symbolic_trace(model)
    modules = dict(fx_model.named_modules())
    new_graph = copy.deepcopy(fx_model.graph)

    new_graph.lint()
    return fx.GraphModule(fx_model, new_graph)
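
# Example of how these passes might be applied (illustrative sketch; assumes a CUDA
# build with torchvision installed, and that fuse_conv_bn is the pass defined earlier
# in this file):
#
#   import torchvision
#   model = torchvision.models.resnet50().cuda().eval()
#   model = fuse_conv_bn(model)
#   model = fuse_conv_relu(model)
#   with torch.no_grad():
#       out = model(torch.randn(8, 3, 224, 224, device="cuda"))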