From 3dcc177af9fce1f55ff247b0bcb6bafecd9e8c2e Mon Sep 17 00:00:00 2001
From: fanyunqian
Date: Wed, 6 Jul 2022 09:49:53 +0800
Subject: [PATCH 1/2] [Refactor] Split PTQ reconstruction into 2 steps:
 1. find the layer node list of each block;
 2. do reconstruction based on those lists.
---
 mqbench/advanced_ptq.py | 73 ++++++++++++++++++++++-------------------
 1 file changed, 40 insertions(+), 33 deletions(-)

diff --git a/mqbench/advanced_ptq.py b/mqbench/advanced_ptq.py
index f9ede726..a1925ae6 100644
--- a/mqbench/advanced_ptq.py
+++ b/mqbench/advanced_ptq.py
@@ -577,6 +577,9 @@ def ptq_reconstruction(model: GraphModule, cali_data: list, config: dict, graph_
     enable_quantization(quant_model)
     torch.cuda.empty_cache()
     checked_nodes = dict()
+
+    # setup for the reconstruction block node list
+    block_list = []
     for node in nodes:
         if 'exclude_node_prefix' in config:
             cont = False
@@ -633,41 +636,45 @@ def ptq_reconstruction(model: GraphModule, cali_data: list, config: dict, graph_
             continue
         logger.info('the node list is below!')
         logger.info(layer_node_list)
-        fp32_module = fp32_modules[qnode2fpnode_dict[layer_node_list[-1]]]
-        fp32_all_inps = []
-        quant_all_inps = []
-        fp32_final_oups = None
-        out_is_cached = False
-        for _node in layer_node_list:
-            if all([arg in layer_node_list for arg in _flatten_args(_node.args) if isinstance(arg, torch.fx.Node)]):
-                continue
-            else:
-                fp32_inp_module = fp32_modules[qnode2fpnode_dict[_node]]
-                quant_module = quant_modules[_node]
-                # fp32 inps: [out_b1, out_b2, ...]
-                _, fp32_inps = save_inp_oup_data(fp32_model, None, fp32_inp_module, cali_data,
-                                                 store_inp=False, store_oup=(config.prob < 1.0), keep_gpu=config.keep_gpu)
-                _, fp32_oups = save_inp_oup_data(fp32_model, None, fp32_module, cali_data,
-                                                 store_inp=False, store_oup=(not out_is_cached), keep_gpu=config.keep_gpu)
-                _, quant_inps = save_inp_oup_data(quant_model, None, quant_module, cali_data,
-                                                  store_inp=False, store_oup=True, keep_gpu=config.keep_gpu)
-                fp32_all_inps.append(fp32_inps)
-                quant_all_inps.append(quant_inps)
-                if not out_is_cached:
-                    fp32_final_oups = fp32_oups
-                    out_is_cached = True
-        cached_inps = (quant_all_inps, fp32_all_inps) if config.prob < 1.0 else quant_all_inps
-        cached_oups = fp32_final_oups
-        quant_modules_by_name = dict()
-        for node in layer_node_list:
-            if node.op == 'call_module':
-                quant_modules_by_name[node.target] = quant_modules[node]
-        subgraph = extract_subgraph(quant_modules_by_name, layer_node_list,
-                                    layer_node_list[-1], g2node)
-        logger.info(subgraph.code)
-        subgraph_reconstruction(subgraph, cached_inps, cached_oups, config)
+        block_list.append(layer_node_list)
         for x in layer_node_list:
             checked_nodes[x] = True
+
+    for layer_node_list in block_list:
+        fp32_module = fp32_modules[qnode2fpnode_dict[layer_node_list[-1]]]
+        fp32_all_inps = []
+        quant_all_inps = []
+        fp32_final_oups = None
+        out_is_cached = False
+        for _node in layer_node_list:
+            if all([arg in layer_node_list for arg in _flatten_args(_node.args) if isinstance(arg, torch.fx.Node)]):
+                continue
+            else:
+                fp32_inp_module = fp32_modules[qnode2fpnode_dict[_node]]
+                quant_module = quant_modules[_node]
+                # fp32 inps: [out_b1, out_b2, ...]
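+                # save_inp_oup_data runs cali_data through the model with a hook
+                # on the given module and returns its cached inputs/outputs:
+                # fp32 activations are stored only when QDrop mixes fp32 and
+                # quant inputs (config.prob < 1.0), the fp32 output of the
+                # block's last layer is cached once as the reconstruction
+                # target, and quant activations are always stored.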
+                _, fp32_inps = save_inp_oup_data(fp32_model, None, fp32_inp_module, cali_data,
+                                                 store_inp=False, store_oup=(config.prob < 1.0), keep_gpu=config.keep_gpu)
+                _, fp32_oups = save_inp_oup_data(fp32_model, None, fp32_module, cali_data,
+                                                 store_inp=False, store_oup=(not out_is_cached), keep_gpu=config.keep_gpu)
+                _, quant_inps = save_inp_oup_data(quant_model, None, quant_module, cali_data,
+                                                  store_inp=False, store_oup=True, keep_gpu=config.keep_gpu)
+                fp32_all_inps.append(fp32_inps)
+                quant_all_inps.append(quant_inps)
+                if not out_is_cached:
+                    fp32_final_oups = fp32_oups
+                    out_is_cached = True
+        cached_inps = (quant_all_inps, fp32_all_inps) if config.prob < 1.0 else quant_all_inps
+        cached_oups = fp32_final_oups
+        quant_modules_by_name = dict()
+        for node in layer_node_list:
+            if node.op == 'call_module':
+                quant_modules_by_name[node.target] = quant_modules[node]
+        subgraph = extract_subgraph(quant_modules_by_name, layer_node_list,
+                                    layer_node_list[-1], g2node)
+        logger.info(subgraph.code)
+        subgraph_reconstruction(subgraph, cached_inps, cached_oups, config)
+
     disable_all(quant_model)
     for node in checked_nodes:
         if node.op == 'call_module':

From e3a65b1b8cbee9a37794443c0c44d8b174a2a576 Mon Sep 17 00:00:00 2001
From: fanyunqian
Date: Wed, 6 Jul 2022 14:57:16 +0800
Subject: [PATCH 2/2] [Feature] Automatically add fake modules to
 `call_function` nodes.
---
 mqbench/advanced_ptq.py | 154 ++++++++++++++++++++++++++++++----------
 1 file changed, 115 insertions(+), 39 deletions(-)

diff --git a/mqbench/advanced_ptq.py b/mqbench/advanced_ptq.py
index a1925ae6..405aa872 100644
--- a/mqbench/advanced_ptq.py
+++ b/mqbench/advanced_ptq.py
@@ -47,6 +47,31 @@ def qnode2fpnode(quant_modules, fp32_modules):
     qnode2fpnode_dict = {quant_named_nodes[key]: fp32_named_nodes[key] for key in quant_named_nodes}
     return qnode2fpnode_dict
 
+
+def insert_fake_modules(model, node_list, direction):
+    graph = model.graph
+    for node in node_list:
+        # node.name is a plain, dot-free string even for call_function nodes,
+        # so it is safe both as an attribute name and as a call_module target
+        inserted_node_target = node.name + '_fake_module_' + direction
+        setattr(model, inserted_node_target, torch.nn.Identity())
+        if direction == 'input':
+            with graph.inserting_before(node):
+                inserted_node = graph.create_node(op='call_module',
+                                                  name=inserted_node_target,
+                                                  target=inserted_node_target,
+                                                  args=node.args,
+                                                  kwargs=node.kwargs)
+        elif direction == 'output':
+            with graph.inserting_after(node):
+                inserted_node = graph.create_node(op='call_module',
+                                                  name=inserted_node_target,
+                                                  target=inserted_node_target,
+                                                  args=(node,),
+                                                  kwargs={})
+    model.recompile()
+    model.graph.lint()
+    return model
+
+
 def layer_has_weights(nodes, modules):
     has_weights = False
     for node in nodes:
@@ -233,6 +258,20 @@ def _flatten_args(node):
         flattned_args.extend([node])
     return flattned_args
 
+def get_io_of_block(nodes):
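+    # a block input is a node none of whose args/kwargs are produced inside
+    # the block; a block output is a node that no other node in the block uses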
+    used_list = []
+    input_list = []
+    for node in nodes:
+        if all([arg not in nodes for arg in _flatten_args(node.kwargs)]) and all([arg not in nodes for arg in _flatten_args(node.args)]):
+            input_list.append(node)
+        for arg in _flatten_args(node.kwargs):
+            if arg in nodes and arg not in used_list:
+                used_list.append(arg)
+        for arg in _flatten_args(node.args):
+            if arg in nodes and arg not in used_list:
+                used_list.append(arg)
+    output_list = [node for node in nodes if node not in used_list]
+    return input_list, output_list
 
 def find_used_times(nodes, target):
     used = len([_node for _node in target.users if _node in nodes])
@@ -240,7 +279,6 @@ def find_used_times(nodes, target):
 
 
 
-
 def find_cur_node(layer_node_list):
     node_list = []
     used_later = []
@@ -497,44 +535,7 @@ def extract_block(input_nodes, fp32_modules, depth=0):
     return layer_node_list + exp_nodes + extract_block(
         [exp_nodes[-1]], fp32_modules, depth + 1)
 
-
-def ptq_reconstruction(model: GraphModule, cali_data: list, config: dict, graph_module_list: list = None):
-    r"""
-    Reconsturction for AdaRound, BRECQ, QDrop.
-    Basic optimization objective:
-
-    .. math::
-
-        \mathop{\arg\min}_{\mathbf{V}}\ \ || Wx-\tilde{W}x ||_F^2 + \lambda f_{reg}(\mathbf{V}),
-
-        \tilde{W}=s \cdot clip\left( \left\lfloor\dfrac{W}{s}\right\rfloor+h(\mathbf{V}), n, p \right)
-
-    where :math:`h(\mathbf{V}_{i,j})=clip(\sigma(\mathbf{V}_{i,j})(\zeta-\gamma)+\gamma, 0, 1)`, and :math:`f_{reg}(\mathbf{V})=\mathop{\sum}_{i,j}{1-|2h(\mathbf{V}_{i,j})-1|^\beta}`. By annealing on :math:`\beta`, the rounding mask can adapt freely in initial phase and converge to 0 or 1 in later phase.
-
-    Args:
-        model (torch.nn.Module): a prepared GraphModule to do PTQ
-        cali_data (List): a list of calibration tensor
-        config (dict): a config for PTQ reconstruction
-        graph_module_list (list): a list of model's children modules which need quantization. if this is used, the model is partial quantized; if not, the model is fully quantized.
-
-    >>> sample config : {
-            pattern: block (str, Available options are [layer, block].)
-            scale_lr: 4.0e-5 (learning rate for learning step size of activation)
-            warm_up: 0.2 (0.2 * max_count iters without regularization to floor or ceil)
-            weight: 0.01 (loss weight for regularization item)
-            max_count: 20000 (optimization iteration)
-            b_range: [20,2] (beta decaying range )
-            keep_gpu: True (calibration data restore in gpu or cpu)
-            round_mode: learned_hard_sigmoid (ways to reconstruct the weight, currently only support learned_hard_sigmoid)
-            prob: 0.5 (dropping probability of QDROP)
-        }
-
-    """
-    # assert model is on cuda
-    if not config.keep_gpu:
-        cali_data = [to_device(inp, 'cpu') for inp in cali_data]
-    '''set state first'''
-
+def prepare_fp_and_quant_model_for_ptq(model, graph_module_list):
     fp32_model = model
     fp32_model.eval()
     if graph_module_list is None:
@@ -577,6 +578,47 @@ def ptq_reconstruction(model: GraphModule, cali_data: list, config: dict, graph_
     enable_quantization(quant_model)
     torch.cuda.empty_cache()
     checked_nodes = dict()
+    return fp32_model, quant_model, nodes, g2node, fp32_modules, quant_modules, topology_order_by_node, qnode2fpnode_dict, checked_nodes
+
+def ptq_reconstruction(model: GraphModule, cali_data: list, config: dict, graph_module_list: list = None):
+    r"""
+    Reconstruction for AdaRound, BRECQ, QDrop.
+    Basic optimization objective:
+
+    .. math::
+
+        \mathop{\arg\min}_{\mathbf{V}}\ \ || Wx-\tilde{W}x ||_F^2 + \lambda f_{reg}(\mathbf{V}),
+
+        \tilde{W}=s \cdot clip\left( \left\lfloor\dfrac{W}{s}\right\rfloor+h(\mathbf{V}), n, p \right)
+
+    where :math:`h(\mathbf{V}_{i,j})=clip(\sigma(\mathbf{V}_{i,j})(\zeta-\gamma)+\gamma, 0, 1)`, and :math:`f_{reg}(\mathbf{V})=\mathop{\sum}_{i,j}{1-|2h(\mathbf{V}_{i,j})-1|^\beta}`. By annealing on :math:`\beta`, the rounding mask can adapt freely in the initial phase and converge to 0 or 1 in the later phase.
+
+    Args:
+        model (torch.nn.Module): a prepared GraphModule to do PTQ
+        cali_data (List): a list of calibration tensors
+        config (dict): a config for PTQ reconstruction
+        graph_module_list (list): a list of the model's children modules which need quantization. If this is used, the model is partially quantized; if not, the model is fully quantized.
+
+    >>> sample config : {
+            pattern: block (str, Available options are [layer, block].)
+            scale_lr: 4.0e-5 (learning rate for learning the step size of activation)
+            warm_up: 0.2 (0.2 * max_count iters without regularization to floor or ceil)
+            weight: 0.01 (loss weight for the regularization term)
+            max_count: 20000 (number of optimization iterations)
+            b_range: [20,2] (beta decaying range)
+            keep_gpu: True (whether calibration data is kept on gpu or cpu)
+            round_mode: learned_hard_sigmoid (way to reconstruct the weight; currently only learned_hard_sigmoid is supported)
+            prob: 0.5 (dropping probability of QDrop)
+        }
+
+    """
+    # assert model is on cuda
+    if not config.keep_gpu:
+        cali_data = [to_device(inp, 'cpu') for inp in cali_data]
+    '''set state first'''
+
+    fp32_model, quant_model, nodes, g2node, fp32_modules, quant_modules, topology_order_by_node, qnode2fpnode_dict, checked_nodes = \
+        prepare_fp_and_quant_model_for_ptq(model, graph_module_list)
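+
+    # step 1: only collect each block's node list here; the actual
+    # reconstruction is deferred so that fake modules can first be inserted
+    # at the block boundaries found by get_io_of_block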
 
     # setup for the reconstruction block node list
     block_list = []
@@ -640,6 +682,40 @@ def ptq_reconstruction(model: GraphModule, cali_data: list, config: dict, graph_
         for x in layer_node_list:
             checked_nodes[x] = True
 
+    # insert fake modules at the block inputs/outputs
+    fake_module_dict = {'input': [], 'output': []}
+    for idx, layer_node_list in enumerate(block_list):
+        input_nodes, output_nodes = get_io_of_block(layer_node_list)
+        logger.info('block {}: inputs {}, outputs {}'.format(idx, input_nodes, output_nodes))
+        block_list[idx] = [node.name for node in layer_node_list]
+        for onode in output_nodes:
+            if onode.name not in fake_module_dict['output']:
+                fake_module_dict['output'].append(onode.name)
+        for inode in input_nodes:
+            if inode.name not in fake_module_dict['input']:
+                fake_module_dict['input'].append(inode.name)
+    # re-build the model
+    if len(fake_module_dict['input']) == 0 and len(fake_module_dict['output']) == 0:
+        pass
+    else:
+        model_with_fake_module = deepcopy_graphmodule(model) if graph_module_list is None else deepcopy_mixedmodule(model, graph_module_list)
+        fake_module_dict['input'] = [node for node in model_with_fake_module.graph.nodes if node.name in fake_module_dict['input']]
+        fake_module_dict['output'] = [node for node in model_with_fake_module.graph.nodes if node.name in fake_module_dict['output']]
+        model_with_fake_module = insert_fake_modules(model_with_fake_module, fake_module_dict['input'], 'input')
+        model_with_fake_module = insert_fake_modules(model_with_fake_module, fake_module_dict['output'], 'output')
+        fp32_model, quant_model, nodes, g2node, fp32_modules, quant_modules, topology_order_by_node, qnode2fpnode_dict, checked_nodes = \
+            prepare_fp_and_quant_model_for_ptq(model_with_fake_module, graph_module_list)
+        for direction in fake_module_dict:
+            for node in fake_module_dict[direction]:
+                for idx, layer_node_list in enumerate(block_list):
+                    if node.name in layer_node_list:
+                        block_list[idx].append(node.name + '_fake_module_' + direction)
+    qname2node = {node.name: node for node in nodes}
+    for idx, layer_node_name_list in enumerate(block_list):
+        layer_node_list = [qname2node[node_name] for node_name in layer_node_name_list]
+        block_list[idx] = sorted(layer_node_list, key=lambda x: topology_order_by_node[x])
+
     for layer_node_list in block_list:
         fp32_module = fp32_modules[qnode2fpnode_dict[layer_node_list[-1]]]
         fp32_all_inps = []
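
Note on why the fake modules are needed at all: subgraph reconstruction caches block inputs and outputs with forward hooks (save_inp_oup_data), and hooks can only be registered on nn.Module instances. When a block boundary lands on a `call_function` node (a traced `+`, `torch.cat`, ...), there is no module to hook, so insert_fake_modules wraps such boundary nodes in registered torch.nn.Identity modules that pass tensors through unchanged. Below is a minimal standalone sketch of the same FX graph surgery; the toy model and all names are illustrative, not MQBench code:

    import torch
    import torch.fx

    class Toy(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 3, 1)

        def forward(self, x):
            return self.conv(x) + x  # `+` traces to a call_function node

    traced = torch.fx.symbolic_trace(Toy())
    for node in list(traced.graph.nodes):
        if node.op == 'call_function':  # no module to hook here ...
            target = node.name + '_fake_module_output'
            setattr(traced, target, torch.nn.Identity())  # register the module
            with traced.graph.inserting_after(node):
                # ... so feed the node's output through a registered Identity;
                # a forward hook on it then observes the boundary tensor
                traced.graph.create_node(op='call_module', target=target,
                                         args=(node,), kwargs={}, name=target)
    traced.recompile()
    traced.graph.lint()

Since Identity is numerically a no-op, inserting it changes neither the fp32 nor the quantized computation; it only gives the data-caching hooks a module to attach to at boundaries that would otherwise be bare function calls.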