import torch.fx as fx
from typing import Type, Dict, Any, Tuple, Iterable
import torch
import copy

def _parent_name(target: str) -> Tuple[str, str]:
    """
    Splits a qualname into parent path and last atom.
    For example, `foo.bar.baz` -> (`foo.bar`, `baz`)
    """
    *parent, name = target.rsplit('.', 1)
    return parent[0] if parent else '', name

# Works for length-2 patterns with 2 modules
def matches_module_pattern(pattern: Iterable[Type], node: fx.Node, modules: Dict[str, Any]) -> bool:
    if len(node.args) == 0:
        return False
    # Candidate pair: (the node producing this node's first input, this node itself).
    nodes: Tuple[Any, fx.Node] = (node.args[0], node)
    for expected_type, current_node in zip(pattern, nodes):
        if not isinstance(current_node, fx.Node):
            return False
        if current_node.op != 'call_module':
            return False
        if not isinstance(current_node.target, str):
            return False
        if current_node.target not in modules:
            return False
        if type(modules[current_node.target]) is not expected_type:
            return False
    return True


def replace_node_module(node: fx.Node, modules: Dict[str, Any], new_module: torch.nn.Module):
    assert isinstance(node.target, str)
    parent_name, name = _parent_name(node.target)
    setattr(modules[parent_name], name, new_module)

def computeUpdatedConvWeightAndBias(
        bn_rv,
        bn_eps,
        bn_w,
        bn_b,
        bn_rm,
        conv_w,
        conv_b=None):
    # Fold an eval-mode BatchNorm into the preceding conv:
    #   scale = bn_w / sqrt(running_var + eps)
    #   W' = W * scale   (per output channel)
    #   b' = (conv_b - running_mean) * scale + bn_b
    # The fold is computed in double precision and cast back to the original dtype.
    orig_dtype = bn_rv.dtype
    bn_scale = bn_w / torch.sqrt(bn_rv.to(torch.double) + bn_eps)
    new_w = (conv_w * bn_scale.reshape(-1, 1, 1, 1)).to(orig_dtype)
    if conv_b is None:
        conv_b = torch.zeros_like(bn_rm)
    new_b = ((conv_b - bn_rm) * bn_scale + bn_b).to(orig_dtype)
    return new_w, new_b

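# Illustrative sanity check (a sketch added for exposition, not part of the original
# commit): fold a random Conv2d/BatchNorm2d pair with the helper above and confirm
# that a single conv with the new weight and bias reproduces conv followed by
# eval-mode batch norm. The helper name `_check_fold` and all sizes are arbitrary.
def _check_fold():
    conv = torch.nn.Conv2d(3, 4, kernel_size=3).eval()
    bn = torch.nn.BatchNorm2d(4).eval()
    with torch.no_grad():
        bn.running_mean.normal_()          # non-default stats so the fold is non-trivial
        bn.running_var.uniform_(0.5, 1.5)
        bn.weight.normal_()
        bn.bias.normal_()
    new_w, new_b = computeUpdatedConvWeightAndBias(
        bn.running_var, bn.eps, bn.weight, bn.bias, bn.running_mean,
        conv.weight, conv.bias)
    x = torch.randn(2, 3, 8, 8)
    assert torch.allclose(bn(conv(x)), torch.nn.functional.conv2d(x, new_w, new_b), atol=1e-5)
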
def fuse_conv_bn_eval(conv, bn):
    assert not (conv.training or bn.training), "Fusion only for eval!"
    fused_conv = copy.deepcopy(conv)

    # Fold the BN statistics and affine parameters into the conv's weight and bias.
    new_w, new_b = computeUpdatedConvWeightAndBias(
        bn.running_var, bn.eps, bn.weight, bn.bias, bn.running_mean,
        fused_conv.weight, fused_conv.bias)
    fused_conv.weight = torch.nn.Parameter(new_w)
    fused_conv.bias = torch.nn.Parameter(new_b)

    return fused_conv

def fuse_conv_bn(model: torch.nn.Module, inplace=False) -> torch.nn.Module:
    """
    Fuses convolution/BN layers for inference purposes. Will deepcopy your
    model by default, but can modify the model inplace as well.
    """
    patterns = [(torch.nn.Conv2d, torch.nn.BatchNorm2d)]
    if not inplace:
        model = copy.deepcopy(model)
    fx_model = fx.symbolic_trace(model)
    modules = dict(fx_model.named_modules())
    new_graph = copy.deepcopy(fx_model.graph)

    for pattern in patterns:
        for node in new_graph.nodes:
            if matches_module_pattern(pattern, node, modules):
                if len(node.args[0].users) > 1:  # Output of conv is used by other nodes
                    continue
                conv = modules[node.args[0].target]
                bn = modules[node.target]
                fused_conv = fuse_conv_bn_eval(conv, bn)
                replace_node_module(node.args[0], modules, fused_conv)
                node.replace_all_uses_with(node.args[0])
                new_graph.erase_node(node)
    return fx.GraphModule(fx_model, new_graph)
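
# Minimal usage sketch (illustrative; the toy model, perturbed statistics, and
# tolerance are arbitrary choices, not from the original commit). Any eval-mode
# model containing Conv2d -> BatchNorm2d pairs that torch.fx can symbolically
# trace can be passed to fuse_conv_bn the same way.
if __name__ == '__main__':
    toy = torch.nn.Sequential(
        torch.nn.Conv2d(3, 8, kernel_size=3, padding=1),
        torch.nn.BatchNorm2d(8),
        torch.nn.ReLU(),
    ).eval()
    with torch.no_grad():
        toy[1].running_mean.normal_()      # non-default stats so fusion actually rewrites the weights
        toy[1].running_var.uniform_(0.5, 1.5)

    fused = fuse_conv_bn(toy)
    x = torch.randn(2, 3, 16, 16)
    assert torch.allclose(toy(x), fused(x), atol=1e-5)
    print(fused.graph)  # the batch-norm call_module node has been erased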