From 91183970bb0beef2db9424822fc221124373e98a Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Wed, 29 Jun 2022 16:56:22 +0200 Subject: [PATCH 01/33] [WIP] fx formalism for BERT --- examples/text-classification/run_glue.py | 19 +- optimum/graphcore/fx/transformations.py | 337 ++++++++++++++++++ optimum/graphcore/fx/utils.py | 142 ++++++++ optimum/graphcore/ipu_configuration.py | 2 + optimum/graphcore/modeling_utils.py | 16 +- .../graphcore/models/bert/modeling_bert.py | 311 ++++++++-------- optimum/graphcore/trainer.py | 6 +- 7 files changed, 655 insertions(+), 178 deletions(-) create mode 100644 optimum/graphcore/fx/transformations.py create mode 100644 optimum/graphcore/fx/utils.py diff --git a/examples/text-classification/run_glue.py b/examples/text-classification/run_glue.py index 1905e7209..afb3d63be 100755 --- a/examples/text-classification/run_glue.py +++ b/examples/text-classification/run_glue.py @@ -469,14 +469,17 @@ def preprocess_function(examples): max_train_samples = min(len(train_dataset), data_args.max_train_samples) train_dataset = train_dataset.select(range(max_train_samples)) - labels = torch.tensor(train_dataset[0]["label"]) - if model.config.problem_type is None: - if model.config.num_labels == 1: - model.config.problem_type = "regression" - elif model.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): - model.config.problem_type = "single_label_classification" - else: - model.config.problem_type = "multi_label_classification" + # labels = torch.tensor(train_dataset[0]["label"]) + # if model.config.problem_type is None: + # if model.config.num_labels == 1: + # model.config.problem_type = "regression" + # elif model.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + # model.config.problem_type = "single_label_classification" + # else: + # model.config.problem_type = "multi_label_classification" + dummy_input = tokenizer("Used to set the model.config.problem_type", return_tensors="pt") + dummy_input["labels"] = torch.tensor(train_dataset[0]["label"]) + model(**dummy_input) if training_args.do_eval: if "validation" not in raw_datasets and "validation_matched" not in raw_datasets: diff --git a/optimum/graphcore/fx/transformations.py b/optimum/graphcore/fx/transformations.py new file mode 100644 index 000000000..9f4296724 --- /dev/null +++ b/optimum/graphcore/fx/transformations.py @@ -0,0 +1,337 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
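The run_glue.py change above works because the classification heads in transformers infer and cache config.problem_type themselves the first time they are called with labels, so a single dummy forward pass replaces the commented-out logic. A minimal standalone sketch of that behaviour (the checkpoint name and label value are placeholders, not part of this patch):

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# One forward call with integer labels is enough for the head to set
# config.problem_type to "single_label_classification".
dummy = tokenizer("Used to set the model.config.problem_type", return_tensors="pt")
dummy["labels"] = torch.tensor([0])
model(**dummy)
assert model.config.problem_type == "single_label_classification"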
+import collections +import re +from typing import TYPE_CHECKING, List, Optional, Union + +import torch + +import poptorch +from optimum.utils import logging + +from ...fx.optimization import ReversibleTransformation, Transformation +from ..modeling_utils import ( + SerializedEmbedding, + SerializedLinear, + get_layer_ipu, + outline_attribute, +) + + +if TYPE_CHECKING: + from torch.fx import GraphModule, Node + +logger = logging.get_logger(__name__) + + +def node_matches_pattern(pattern, node: "Node"): + # TODO: validate that. + name = node.target if isinstance(node.target, str) else node.name.replace("_", ".") + return re.match(pattern, name) + + +class AddPoptorchBlockBase(ReversibleTransformation): + def __init__( + self, block_name: str, layer_ipu: Union[int, List[int]], module_name_regex: str, log_insertions: bool = False + ): + self.block_name = block_name + self.layer_ipu = layer_ipu + self.module_name_regex = re.compile(module_name_regex) if module_name_regex is not None else None + self.log_insertions = log_insertions + + def find_start_nodes(self, graph_module: "GraphModule") -> List["Node"]: + nodes = [] + prefixes = set() + for node in graph_module.graph.nodes: + # TODO: how to match the case where node.target is str + match = re.match(self.module_name_regex, node.target) if isinstance(node.target, str) else None + if match: + prefix = match.group(0) + if prefix not in prefixes: + nodes.append(node) + prefixes.add(match.group(0)) + return nodes + + def insert_start_block_node(self, graph_module: "GraphModule", node: "Node", block_name: str, ipu_id: int): + + if node.op in ["placeholder", "output"]: + raise RuntimeError("You cannot insert a start block op before a placeholder or an output.") + + with graph_module.graph.inserting_before(node): + new_node = graph_module.graph.call_function(poptorch.Block.start, (block_name,), {"ipu_id": ipu_id}) + new_node.parent_module_qualified_name = node.parent_module_qualified_name + new_node.was_transformed = f"{self.__class__.__name__}" + + # def start_block(inputs_to_forward, name, ipu_id): + # poptorch.Block.start(name, ipu_id=ipu_id) + # if len(inputs_to_forward) != 1: + # return inputs_to_forward + # return inputs_to_forward[0] + + # with graph_module.graph.inserting_before(node): + # new_node = graph_module.graph.call_function(start_block, (node.args, block_name, ipu_id)) + # if node.op != "get_attr": + # new_args = [] + # if len(node.args) > 1: + # for idx, _ in enumerate(node.args): + # new_args.append(graph_module.graph.call_function(operator.getitem, (new_node, idx))) + # elif node.args: + # new_args.append(new_node) + # else: + # raise NotImplementedError( + # f"Inserting start block op before a {node.op} that does not take any argument is not supported." 
+ # ) + # node.args = tuple(new_args) + + # new_node.was_transformed = f"{self.__class__.__name__}" + # new_node.orig_node = node + + def get_ipu_for_index(self, index: Optional[int] = None) -> int: + if isinstance(self.layer_ipu, list): + if index is None: + raise ValueError("You must provide an index when layer_ipu is a list.") + return self.layer_ipu[index] + return self.layer_ipu + + def reverse(self, graph_module: "GraphModule") -> "GraphModule": + for node in graph_module.graph.nodes: + if getattr(node, "was_transformed", "") == self.__class__.__name__: + graph_module.graph.erase_node(node) + return graph_module + + +class AddPoptorchBlocksInSeries(AddPoptorchBlockBase): + def transform(self, graph_module: "GraphModule") -> "GraphModule": + nodes = self.find_start_nodes(graph_module) + for index, node in enumerate(nodes): + ipu_id = self.get_ipu_for_index(index) + name = f"{self.block_name} {index}" + if self.log_insertions: + logger.info(f"{name} --> IPU {ipu_id}") + self.insert_start_block_node(graph_module, node, name, ipu_id) + return graph_module + + +class AddPoptorchBlock(AddPoptorchBlockBase): + def transform(self, graph_module: "GraphModule") -> "GraphModule": + start_nodes = self.find_start_nodes(graph_module) + if not start_nodes: + return graph_module + node = start_nodes[0] + ipu_id = self.get_ipu_for_index() + if self.log_insertions: + logger.info(f"{self.block_name} --> IPU {ipu_id}") + self.insert_start_block_node(graph_module, node, f"{self.block_name}", ipu_id) + return graph_module + + +class AutoParallelizeAutoEncoder(ReversibleTransformation): + pass + + +class TupleOutput(Transformation): + def transform(self, graph_module: "GraphModule") -> "GraphModule": + for node in graph_module.graph.nodes: + if node.op == "output": + if isinstance(node.args[0], dict): + node.args = (tuple(node.args[0].values()),) + return graph_module + + +class ClipValues(Transformation): + def __init__(self, clip_value: float): + self.clip_value = clip_value + + def _clip_node_args(self, args): + if isinstance(args, (tuple, list, set)): + return type(args)(self._clip_node_args(arg) for arg in args) + elif isinstance(args, dict): + return {name: self._clip_node_args(arg) for name, arg in args.items()} + elif isinstance(args, (float, int)): + return min(max(args, -self.clip_value), self.clip_value) + else: + return args + + def transform(self, graph_module: "GraphModule") -> "GraphModule": + for node in graph_module.graph.nodes: + if node.op == "call_method" and node.target == "view": + continue + node.args = self._clip_node_args(node.args) + return graph_module + + +class OutlineAttribute(ReversibleTransformation): + def __init__(self, name_regex: str, value: str): + self.name_regex = re.compile(name_regex) + self.value = value + + def transform(self, graph_module: "GraphModule") -> "GraphModule": + first_match, last_match = None, None + for node in graph_module.graph.nodes: + # TODO: how to match the case where node.target is str + match = re.match(self.name_regex, node.target) if isinstance(node.target, str) else False + if match: + if first_match is None: + first_match = node + last_match = node + if first_match is None: + raise RuntimeError(f"Could not find any op matching {self.name_regex} to outline.") + + with graph_module.graph.inserting_before(first_match): + new_node = graph_module.graph.call_function(torch.ops.poptorch.set_attribute, ("__outline", "layer", self.value)) + new_node.parent_module_qualified_name = first_match.parent_module_qualified_name + with 
graph_module.graph.inserting_after(last_match): + new_node = graph_module.graph.call_function(torch.ops.poptorch.clear_attribute, ("__outline", "layer")) + new_node.parent_module_qualified_name = first_match.parent_module_qualified_name + return graph_module + + def reverse(self, graph_module: "GraphModule") -> "GraphModule": + has_clear_attribute_to_erase = False + for node in graph_module.graph.nodes: + if node.target is torch.ops.poptorch.set_attribute: + if node.args[2] == self.value: + graph_module.graph.erase_node(node) + has_clear_attribute_to_erase = True + if node.target is torch.ops.poptorch.clear_attribute and has_clear_attribute_to_erase: + graph_module.graph.erase_node(node) + has_clear_attribute_to_erase = False + return graph_module + + +class RecomputationCheckpoint(ReversibleTransformation): + def __init__(self, name_regex: str, to_exclude: Optional[str] = None): + self.name_regex = re.compile(name_regex) + self.to_exclude = re.compile(to_exclude) if to_exclude is not None else None + + def find_output_nodes_for_module_name(self, graph_module: "GraphModule", module_qualified_name: str): + nodes_in_module = set() + first_match = False + for n in graph_module.graph.nodes: + starts_with_module_qualified_name = getattr(n, "parent_module_qualified_name", "").startswith(module_qualified_name) + print(getattr(n, "parent_module_qualified_name", "")) + print(starts_with_module_qualified_name) + if not first_match and not starts_with_module_qualified_name: + continue + elif not first_match and starts_with_module_qualified_name: + first_match = True + nodes_in_module.add(n) + elif first_match and starts_with_module_qualified_name: + nodes_in_module.add(n) + else: + break + nodes_in_module = {n for n in nodes_in_module if set(n.users.keys()) & nodes_in_module} + return [n for n in nodes_in_module if set(n.users.keys()) - nodes_in_module] + + def transform(self, graph_module: "GraphModule") -> "GraphModule": + matched_module_names = collections.OrderedDict() + for node in graph_module.graph.nodes: + match = re.match(self.name_regex, getattr(node, "parent_module_qualified_name", "")) + to_exclude = False + if self.to_exclude is not None: + to_exclude = re.match(self.to_exclude, getattr(node, "parent_module_qualified_name", "")) + if match and not to_exclude: + matched_module_names[match.group(0)] = None + + output_nodes = [] + for qualified_name in matched_module_names.keys(): + print(qualified_name) + output_nodes += self.find_output_nodes_for_module_name(graph_module, qualified_name) + + for output in output_nodes: + with graph_module.graph.inserting_after(output): + print("output", output) + recomputation_node = graph_module.graph.call_function(poptorch.recomputationCheckpoint) + output.replace_all_uses_with(recomputation_node) + recomputation_node.args = (output,) + + return graph_module + + def reverse(self, graph_module: "GraphModule") -> "GraphModule": + for node in graph_module.graph.nodes: + if node.target == poptorch.recomputationCheckpoint: + node.replace_all_uses_with(node.args[0]) + graph_module.graph.erase_node(node) + return graph_module + + +class VocabEmbeddingToSerializedEmbedding(ReversibleTransformation): + def __init__(self, name_regex: Optional[str] = None): + self.name_regex = re.compile(name_regex) if name_regex else None + + def transform(self, graph_module: "GraphModule") -> "GraphModule": + embedding_nodes = [] + for node in graph_module.graph.nodes: + if node.op != "call_module": + continue + match = re.match(self.name_regex, node.target) if self.name_regex 
is not None else True + if match and isinstance(graph_module.get_submodule(node.target), torch.nn.Embedding): + embedding_nodes.append(node) + + # We assume the vocab embedding to be the embedding with the maximum number of embeddings. + if not embedding_nodes: + raise RuntimeError("Could not find any embedding node") + + embedding_node = max(embedding_nodes, key=lambda node: graph_module.get_submodule(node.target).num_embeddings) + parent_fully_qualified_name, embedding_name = embedding_node.target.rsplit(".", maxsplit=1) + new_embedding = SerializedEmbedding( + graph_module.get_submodule(embedding_node.target), graph_module.ipu_config.embedding_serialization_factor + ) + setattr(graph_module.get_submodule(parent_fully_qualified_name), embedding_name, new_embedding) + embedding_node.was_transformed = "VocabEmbeddingToSerializedEmbedding" + + return graph_module + + def reverse(self, graph_module: "GraphModule") -> "GraphModule": + for node in graph_module.graph.nodes: + if getattr(node, "was_transformed", "") == "VocabEmbeddingToSerializedEmbedding": + parent_fully_qualified_name, embedding_name = node.target.rsplit(".", maxsplit=1) + setattr( + graph_module.get_submodule(parent_fully_qualified_name), + embedding_name, + graph_module.get_submodule(node.target).deserialize(), + ) + break + return graph_module + + +class LinearToSerializedLinear(ReversibleTransformation): + def __init__(self, name_regex: str): + self.name_regex = re.compile(name_regex) if name_regex else None + + def transform(self, graph_module: "GraphModule") -> "GraphModule": + for node in graph_module.graph.nodes: + if node.op != "call_module": + continue + match = re.match(self.name_regex, node.target) if self.name_regex is not None else True + if match and isinstance(graph_module.get_submodule(node.target), torch.nn.Linear): + linear = graph_module.get_submodule(node.target) + serialized_linear = SerializedLinear( + graph_module.config.hidden_size, + graph_module.config.vocab_size, + graph_module.ipu_config.embedding_serialization_factor, + bias=linear.bias is not None, + mode=poptorch.MatMulSerializationMode.OutputChannels, + ) + serialized_linear.load_state_dict(linear.state_dict()) + parent_fully_qualified_name, linear_name = node.target.rsplit(".", maxsplit=1) + setattr(graph_module.get_submodule(parent_fully_qualified_name), linear_name, serialized_linear) + graph_module.tie_weights() + return graph_module + + def reverse(self, graph_module: "GraphModule") -> "GraphModule": + for node in graph_module.graph.nodes: + if node.op == "call_module" and isinstance(graph_module.get_submodule(node.target), SerializedLinear): + graph_module.get_submodule(node.target).__class__ = torch.nn.Linear + return graph_module diff --git a/optimum/graphcore/fx/utils.py b/optimum/graphcore/fx/utils.py new file mode 100644 index 000000000..03e8e51f4 --- /dev/null +++ b/optimum/graphcore/fx/utils.py @@ -0,0 +1,142 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
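The block, outlining, recomputation and serialization transformations above are driven through optimum's compose helper in the BERT modeling changes later in this patch; a condensed sketch of that usage, assuming a BERT GraphModule already produced by the tracer (traced), a layer_ipu list such as get_layer_ipu returns, and the model's number of hidden layers:

from optimum.fx.optimization import compose
from optimum.graphcore.fx.transformations import (
    AddPoptorchBlock,
    AddPoptorchBlocksInSeries,
    RecomputationCheckpoint,
)

def annotate_bert_graph(traced, layer_ipu, num_hidden_layers):
    # layer_ipu maps encoder layer index -> IPU id, e.g. [0, 1, 1, 2, ...].
    transformations = [
        AddPoptorchBlock("Embedding", 0, "bert.embeddings"),
        AddPoptorchBlocksInSeries("Encoder", layer_ipu, r"bert.encoder.layer.[0-9]+"),
        RecomputationCheckpoint(
            r"bert.encoder.layer.[0-9]+",
            to_exclude=f"bert.encoder.layer.{num_hidden_layers - 1}",
        ),
    ]
    composition = compose(*transformations)
    traced = composition(traced)               # insert poptorch blocks / checkpoints
    # ... compile and run with poptorch ...
    return composition(traced, reverse=True)   # reversible: restore the original graph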
+import math +from typing import TYPE_CHECKING, Callable, Optional, List + +import torch + +import transformers +from transformers.utils.fx import HFTracer, check_if_model_is_supported, get_concrete_args, _gen_constructor_wrapper + +from ..modeling_utils import PipelineMixin + +if TYPE_CHECKING: + from transformers import PreTrainedModel + + +class PipelinedTracer(HFTracer): + def __init__(self, autowrap_modules=(math,), autowrap_functions=()): + super().__init__(autowrap_modules=autowrap_modules, autowrap_functions=autowrap_functions) + self.ops_to_wrap = [] + self.current_module_qualified_name = ["root"] + + def register_op_to_wrap(self, name, wrapper, orig_op): + self.ops_to_wrap.append((name, wrapper, orig_op)) + + def _patch_op(self, op_name: str, op_patched: "Callable"): + names = op_name.split(".") + module_names = names[1:-1] + attr_name = names[-1] + mod = torch + for module_name in module_names: + mod = getattr(mod, module_name) + setattr(mod, attr_name, op_patched) + + def wrap_ops(self): + for name, wrapper, _ in self.ops_to_wrap: + self._patch_op(name, wrapper) + + def unwrap_ops(self): + for name, _, orig_op in self.ops_to_wrap: + self._patch_op(name, orig_op) + + def proxy(self, node): + # Would be better to update the created node in TracerBase.create_node, but this method has less arguments, so + # it is easier to use this one, and equivalent. + node.parent_module_qualified_name = self.current_module_qualified_name[-1] + proxy = super().proxy(node) + return proxy + + def call_module(self, m, forward, args, kwargs): + # Could be done in a "cleaner" fashion by inlining the content of Tracer.call_module. + # Preferred to inherint from it and do it that way instead. + module_qualified_name = self.path_of_module(m) + is_leaf_module = self.is_leaf_module(m, module_qualified_name) + if not is_leaf_module: + self.current_module_qualified_name.append(module_qualified_name) + self.orig_forward = forward + proxy = super().call_module(m, forward, args, kwargs) + if not is_leaf_module: + self.current_module_qualified_name.pop(-1) + return proxy + + +def symbolic_trace_with_pipelined_tracer( + model: PipelineMixin, + input_names: Optional[List[str]] = None, +) -> torch.fx.GraphModule: + + """ + Performs symbolic tracing on the model. + + Args: + model ([`PretrainedModel`]): + The model to trace. + input_names (`List[str]`, *optional*): + The names of the inputs of the traced model. If unset, model.dummy_inputs.keys() are used instead. + Returns: + `torch.fx.GraphModule`: A GraphModule constructed by recording operations seen while tracing the model. + """ + if input_names is None: + input_names = model.dummy_inputs.keys() + + input_names = list(input_names) + concrete_args = get_concrete_args(model, input_names) + + # Tracing. + tracer = PipelinedTracer() + for wrap_info in model.get_ops_to_wrap_for_tracing(): + tracer.register_op_to_wrap(*wrap_info) + tracer.wrap_ops() + traced_graph = tracer.trace(model, concrete_args=concrete_args) + tracer.unwrap_ops() + + traced = torch.fx.GraphModule(model, traced_graph) + + # The model class must be stored as an attribute to allow model deserialization, which uses trace, and thus + # _generate_dummy_input, where the model class is needed. 
+ traced.class_for_deserialization = model.__class__ + traced.device = model.device + + for name, attr in vars(model).items(): + setattr(traced, name, getattr(traced, name, attr)) + + return traced + + +def symbolic_trace_pipelined_model(pipelined_model: PipelineMixin) -> PipelineMixin: + if isinstance(pipelined_model, torch.fx.GraphModule): + return pipelined_model + + transformers_class = None + for base in pipelined_model.__class__.__bases__: + if transformers.PreTrainedModel in base.__mro__: + transformers_class = base + break + + # Trick to make HFTracer._generate_dummy_input work with the pipelined class. + # This attribute will be set properly in symbolic_trace_with_pipelined_tracer once tracing is done. + pipelined_model.class_for_deserialization = transformers_class + + traced = symbolic_trace_with_pipelined_tracer( + pipelined_model, input_names=pipelined_model.input_names + ) + + type_ = type(f"Traced{pipelined_model.__class__.__name__}", (torch.fx.GraphModule, pipelined_model.__class__), {}) + traced.__class__ = type_ + + # traced.ipu_config = pipelined_model.ipu_config + # traced._hooks = pipelined_model._hooks + return traced diff --git a/optimum/graphcore/ipu_configuration.py b/optimum/graphcore/ipu_configuration.py index 868960691..ac7ca8950 100644 --- a/optimum/graphcore/ipu_configuration.py +++ b/optimum/graphcore/ipu_configuration.py @@ -159,6 +159,8 @@ def __init__(self, **kwargs): # TODO: remove this if unnecessary. self.execute_encoder_on_cpu_for_generation = kwargs.pop("execute_encoder_on_cpu_for_generation", False) + self.log_insertions = kwargs.pop("log_insertions", False) + def _prepare_config_attribute_for_pod_type( self, config_attribute_name: str, config_attribute: Union[Any, Dict[str, Any]], pod_type: Optional[str] ) -> Any: diff --git a/optimum/graphcore/modeling_utils.py b/optimum/graphcore/modeling_utils.py index b4c324cd7..c16e3b2f1 100644 --- a/optimum/graphcore/modeling_utils.py +++ b/optimum/graphcore/modeling_utils.py @@ -13,8 +13,8 @@ # limitations under the License. 
import copy -from inspect import signature -from typing import Any, Dict, Optional, Tuple +import inspect +from typing import Any, Callable, Dict, List, Optional, Tuple import torch import torch.nn.functional as F @@ -141,6 +141,16 @@ def ipu_config(self, value: IPUConfig): raise TypeError(f"ipu_config must be an instance of IPUConfig, but {type(value)} was provided") self._ipu_config = value + def get_ops_to_wrap_for_tracing(self) -> List[Tuple[str, Callable, Callable]]: + return [] + + def get_transformations(self): + raise NotImplementedError("You need to implement get_transformations.") + + @property + def input_names(self): + return list(inspect.signature(self.forward).parameters.keys()) + def parallelize(self): """Transforms the model to run in an IPU pipeline.""" self._hooks = [] @@ -211,7 +221,7 @@ def get_encoder( def prepare_inputs_for_generation(self, input_ids: torch.LongTensor, **kwargs) -> Dict[str, Any]: inputs = super().prepare_inputs_for_generation(input_ids, **kwargs) - return {k: v for k, v in inputs.items() if k in signature(self._forward_for_generate).parameters} + return {k: v for k, v in inputs.items() if k in inspect.signature(self._forward_for_generate).parameters} def get_layer_ipu(layers_per_ipu): diff --git a/optimum/graphcore/models/bert/modeling_bert.py b/optimum/graphcore/models/bert/modeling_bert.py index a2055977c..ec8b39938 100644 --- a/optimum/graphcore/models/bert/modeling_bert.py +++ b/optimum/graphcore/models/bert/modeling_bert.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import math from typing import Optional, Tuple, Union @@ -29,20 +30,30 @@ BertForSequenceClassification, BertForTokenClassification, ) -from transformers.modeling_outputs import MaskedLMOutput, QuestionAnsweringModelOutput -from transformers.models.bert.modeling_bert import BertForPreTrainingOutput, BertSelfAttention +from transformers.utils.fx import _gen_constructor_wrapper + +from ....fx.optimization import ChangeTrueDivToMulByInverse, FuseBiasInLinear, MergeLinears, compose +from ...fx.transformations import ( + AddPoptorchBlock, + AddPoptorchBlocksInSeries, + ClipValues, + LinearToSerializedLinear, + OutlineAttribute, + RecomputationCheckpoint, + TupleOutput, + VocabEmbeddingToSerializedEmbedding, +) +from ...fx.utils import symbolic_trace_pipelined_model from ...modeling_utils import ( OnehotGather, PipelineMixin, - SerializedEmbedding, SerializedLinear, get_layer_ipu, outline_attribute, recomputation_checkpoint, register, ) -from .bert_fused_attention import BertFusedSelfAttention logger = logging.get_logger(__name__) @@ -63,6 +74,31 @@ def __init__(self, config): super().__init__(config) self.gather_indices = OnehotGather() + def get_ops_to_wrap_for_tracing(self): + return [ + ("torch.topk", *_gen_constructor_wrapper(torch.topk)), + ("torch.nn.functional.one_hot", *_gen_constructor_wrapper(torch.nn.functional.one_hot)), + ] + + def get_transformations(self): + log_insertions = self.ipu_config.log_insertions + layer_ipu = get_layer_ipu(self.ipu_config.layers_per_ipu) + transformations = [ + ChangeTrueDivToMulByInverse(), + MergeLinears(), + # FuseBiasInLinear(), + AddPoptorchBlock( + "Embedding", layer_ipu=0, module_name_regex="bert.embeddings", log_insertions=log_insertions + ), + OutlineAttribute("bert.embeddings.LayerNorm", "Embedding"), + AddPoptorchBlocksInSeries( + "Encoder", layer_ipu, 
module_name_regex=r"bert.encoder.layer.[0-9]+", log_insertions=log_insertions + ), + AddPoptorchBlock("Pooler Output", 0, "bert.pooler", log_insertions=log_insertions), + AddPoptorchBlock("Classifier Output", 0, "cls", log_insertions=log_insertions), + ] + return transformations + def parallelize(self): """ Transform the model to run in an IPU pipeline. @@ -72,48 +108,15 @@ def parallelize(self): - Adds recomputation checkpoints """ super().parallelize() - - # Use faster fused-qkv self-attention - for layer in self.bert.encoder.layer: - layer.attention.self.__class__ = BertFusedSelfAttention - - if self.ipu_config.embedding_serialization_factor > 1: - serialized_decoder = SerializedLinear( - self.config.hidden_size, - self.config.vocab_size, - self.ipu_config.embedding_serialization_factor, - bias=True, - mode=poptorch.MatMulSerializationMode.OutputChannels, - ) - serialized_decoder.load_state_dict(self.cls.predictions.decoder.state_dict()) - self.cls.predictions.decoder = serialized_decoder - self.tie_weights() - - layer_ipu = get_layer_ipu(self.ipu_config.layers_per_ipu) - - logger.info("-------------------- Device Allocation --------------------") - logger.info("Embedding --> IPU 0") - self.bert.embeddings = poptorch.BeginBlock(self.bert.embeddings, "Embedding", ipu_id=0) - # Preventing the embeddings.LayerNorm from being outlined with the encoder.layer.LayerNorm - # improves the tile mapping of the pipeline stashes - hs = outline_attribute(self.bert.embeddings.LayerNorm, "embeddings") - self._hooks.extend(hs) - - for index, layer in enumerate(self.bert.encoder.layer): - ipu = layer_ipu[index] - if self.ipu_config.recompute_checkpoint_every_layer: - h = recomputation_checkpoint(layer) - self._hooks.append(h) - self.bert.encoder.layer[index] = poptorch.BeginBlock(layer, f"Encoder{index}", ipu_id=ipu) - logger.info(f"Encoder {index:<2} --> IPU {ipu}") - - logger.info("Pooler --> IPU 0") - self.bert.pooler = poptorch.BeginBlock(self.bert.pooler, "Pooler", ipu_id=0) - - logger.info("Classifier --> IPU 0") - self.cls = poptorch.BeginBlock(self.cls, "Classifier", ipu_id=0) - logger.info("-----------------------------------------------------------") - return self + traced = symbolic_trace_pipelined_model(self) + transformations = self.get_transformations() + # if self.ipu_config.embedding_serialization_factor > 1: + # transformations.append(LinearToSerializedLinear("cls.predictions.decoder")) + composition = compose(*transformations) + non_reversible_composition = compose(ClipValues(1e4), TupleOutput()) + traced = composition(traced) + traced = non_reversible_composition(traced) + return traced def deparallelize(self): """ @@ -122,19 +125,11 @@ def deparallelize(self): compatible with the original model. 
""" super().deparallelize() - - for layer in self.bert.encoder.layer: - layer.attention.self.__class__ = BertSelfAttention - + transformations = self.get_transformations() if self.ipu_config.embedding_serialization_factor > 1: - decoder = nn.Linear( - self.config.hidden_size, - self.config.vocab_size, - bias=True, - ) - decoder.load_state_dict(self.cls.predictions.decoder.state_dict()) - self.cls.predictions.decoder = decoder - self.tie_weights() + transformations.append(LinearToSerializedLinear("cls.predictions.decoder")) + composition = compose(*transformations) + self = composition(self, reverse=True) return self def _init_weights(self, module): @@ -237,6 +232,32 @@ def __init__(self, config): super().__init__(config) self.gather_indices = OnehotGather() + def get_ops_to_wrap_for_tracing(self): + return [ + ("torch.topk", *_gen_constructor_wrapper(torch.topk)), + ("torch.nn.functional.one_hot", *_gen_constructor_wrapper(torch.nn.functional.one_hot)), + ] + + def get_transformations(self): + log_insertions = self.ipu_config.log_insertions + layer_ipu = get_layer_ipu(self.ipu_config.layers_per_ipu) + transformations = [ + AddPoptorchBlock( + "Embedding", layer_ipu=0, module_name_regex="bert.embeddings", log_insertions=log_insertions + ), + OutlineAttribute("bert.embeddings.LayerNorm", "Embedding"), + AddPoptorchBlocksInSeries( + "Encoder", layer_ipu, module_name_regex=r"bert.encoder.layer.[0-9]+", log_insertions=log_insertions + ), + AddPoptorchBlock("Classifier Output", 0, "cls", log_insertions=log_insertions), + ] + if self.ipu_config.recompute_checkpoint_every_layer: + transformations.append(RecomputationCheckpoint("bert.encoder.layer.[0-9]+", to_exclude=f"bert.encoder.layer.{self.config.num_hidden_layers - 1}")) + if self.ipu_config.embedding_serialization_factor > 1: + transformations.append(LinearToSerializedLinear("cls.predictions.decoder")) + transformations += [ChangeTrueDivToMulByInverse(), MergeLinears()] + return transformations + def parallelize(self): """ Transform the model to run in an IPU pipeline. 
@@ -246,45 +267,16 @@ def parallelize(self): - Adds recomputation checkpoints """ super().parallelize() + traced = symbolic_trace_pipelined_model(self) + transformations = self.get_transformations() + composition = compose(*transformations) + non_reversible_composition = compose(ClipValues(1e4), TupleOutput()) - # Use faster fused-qkv self-attention - for layer in self.bert.encoder.layer: - layer.attention.self.__class__ = BertFusedSelfAttention - - if self.ipu_config.embedding_serialization_factor > 1: - serialized_decoder = SerializedLinear( - self.config.hidden_size, - self.config.vocab_size, - self.ipu_config.embedding_serialization_factor, - bias=True, - mode=poptorch.MatMulSerializationMode.OutputChannels, - ) - serialized_decoder.load_state_dict(self.cls.predictions.decoder.state_dict()) - self.cls.predictions.decoder = serialized_decoder - self.tie_weights() - - layer_ipu = get_layer_ipu(self.ipu_config.layers_per_ipu) + traced = composition(traced) + traced = non_reversible_composition(traced) + import pdb; pdb.set_trace() - logger.info("-------------------- Device Allocation --------------------") - logger.info("Embedding --> IPU 0") - self.bert.embeddings = poptorch.BeginBlock(self.bert.embeddings, "Embedding", ipu_id=0) - # Preventing the embeddings.LayerNorm from being outlined with the encoder.layer.LayerNorm - # improves the tile mapping of the pipeline stashes - hs = outline_attribute(self.bert.embeddings.LayerNorm, "embeddings") - self._hooks.extend(hs) - - for index, layer in enumerate(self.bert.encoder.layer): - ipu = layer_ipu[index] - if self.ipu_config.recompute_checkpoint_every_layer: - h = recomputation_checkpoint(layer) - self._hooks.append(h) - self.bert.encoder.layer[index] = poptorch.BeginBlock(layer, f"Encoder{index}", ipu_id=ipu) - logger.info(f"Encoder {index:<2} --> IPU {ipu}") - - logger.info("Classifier --> IPU 0") - self.cls = poptorch.BeginBlock(self.cls, "Classifier", ipu_id=0) - logger.info("-----------------------------------------------------------") - return self + return traced def deparallelize(self): """ @@ -294,18 +286,10 @@ def deparallelize(self): """ super().deparallelize() - for layer in self.bert.encoder.layer: - layer.attention.self.__class__ = BertSelfAttention + transformations = self.get_transformations() + composition = compose(*transformations) + self = composition(self, reverse=True) - if self.ipu_config.embedding_serialization_factor > 1: - decoder = nn.Linear( - self.config.hidden_size, - self.config.vocab_size, - bias=True, - ) - decoder.load_state_dict(self.cls.predictions.decoder.state_dict()) - self.cls.predictions.decoder = decoder - self.tie_weights() return self def forward( @@ -383,6 +367,30 @@ def forward( class BertPipelineMixin(PipelineMixin): + def get_transformations(self): + log_insertions = self.ipu_config.log_insertions + layer_ipu = get_layer_ipu(self.ipu_config.layers_per_ipu) + last_ipu = len(self.ipu_config.layers_per_ipu) - 1 + transformations = [ + ChangeTrueDivToMulByInverse(), + MergeLinears(), + AddPoptorchBlock( + "Embedding", layer_ipu=0, module_name_regex="bert.embeddings", log_insertions=log_insertions + ), + OutlineAttribute("bert.embeddings.LayerNorm", "Embedding"), + AddPoptorchBlocksInSeries( + "Encoder", layer_ipu, module_name_regex=r"bert.encoder.layer.[0-9]+", log_insertions=log_insertions + ), + # Only one of the following AddPoptorchBlock, will actually add a block. 
+ AddPoptorchBlock("Classifier Output", last_ipu, "classifier", log_insertions=log_insertions), + AddPoptorchBlock("QA Outputs", last_ipu, "qa_outputs", log_insertions=log_insertions), + ] + return transformations + + @property + def input_names(self): + return ["input_ids", "attention_mask", "token_type_ids", "labels"] + def parallelize(self): """ Transform the model to run in an IPU pipeline. @@ -393,30 +401,26 @@ def parallelize(self): """ super().parallelize() - # Use faster fused-qkv self-attention - for layer in self.bert.encoder.layer: - layer.attention.self.__class__ = BertFusedSelfAttention + # if self.ipu_config.recompute_checkpoint_every_layer: + # for layer in self.bert.encoder.layer[:-1]: + # h = recomputation_checkpoint(layer) + # self._hooks.append(h) - layer_ipu = get_layer_ipu(self.ipu_config.layers_per_ipu) + traced = symbolic_trace_pipelined_model(self) - logger.info("-------------------- Device Allocation --------------------") - logger.info("Embedding --> IPU 0") - if self.ipu_config.embedding_serialization_factor > 1: - self.bert.embeddings.word_embeddings = SerializedEmbedding( - self.bert.embeddings.word_embeddings, self.ipu_config.embedding_serialization_factor - ) - self.bert.embeddings = poptorch.BeginBlock(self.bert.embeddings, "Embedding", ipu_id=0) - hs = outline_attribute(self.bert.embeddings.LayerNorm, "embedding") - self._hooks.extend(hs) - - for index, layer in enumerate(self.bert.encoder.layer): - ipu = layer_ipu[index] - if self.ipu_config.recompute_checkpoint_every_layer and index != self.config.num_hidden_layers - 1: - h = recomputation_checkpoint(layer) - self._hooks.append(h) - self.bert.encoder.layer[index] = poptorch.BeginBlock(layer, f"Encoder{index}", ipu_id=ipu) - logger.info(f"Encoder {index:<2} --> IPU {ipu}") - return self + transformations = self.get_transformations() + + if traced.ipu_config.embedding_serialization_factor > 1: + transformations.append(VocabEmbeddingToSerializedEmbedding()) + + composition = compose(*transformations) + + non_reversible_composition = compose(ClipValues(1e4), TupleOutput()) + + traced = composition(traced) + traced = non_reversible_composition(traced) + + return traced def deparallelize(self): """ @@ -426,12 +430,16 @@ def deparallelize(self): """ super().deparallelize() - for layer in self.bert.encoder.layer: - layer.attention.self.__class__ = BertSelfAttention - - # Deserialize the serialized word embedding + transformations = self.get_transformations() if self.ipu_config.embedding_serialization_factor > 1: - self.bert.embeddings.word_embeddings = self.bert.embeddings.word_embeddings.deserialize() + transformations.append(VocabEmbeddingToSerializedEmbedding()) + + # if self.ipu_config.recompute_checkpoint_every_layer: + # transformations.append(RecomputationCheckpoint()) + + composition = compose(*transformations) + self = composition(self, reverse=True) + return self @@ -445,14 +453,7 @@ class PipelinedBertForSequenceClassification(BertForSequenceClassification, Bert model = PipelinedBertForSequenceClassification(config).parallelize().half() ``` """ - - def parallelize(self): - super().parallelize() - last_ipu = self.ipu_config.ipus_per_replica - 1 - logger.info(f"Classifier Output --> IPU {last_ipu}") - self.classifier = poptorch.BeginBlock(self.classifier, "Classifier Output", ipu_id=last_ipu) - logger.info("-----------------------------------------------------------") - return self + pass @register(BertForMultipleChoice) @@ -465,14 +466,7 @@ class PipelinedBertForMultipleChoice(BertForMultipleChoice, 
BertPipelineMixin): model = PipelinedBertForMultipleChoice(config).parallelize().half() ``` """ - - def parallelize(self): - super().parallelize() - last_ipu = self.ipu_config.ipus_per_replica - 1 - logger.info(f"Classifier Output --> IPU {last_ipu}") - self.classifier = poptorch.BeginBlock(self.classifier, "Classifier Output", ipu_id=last_ipu) - logger.info("-----------------------------------------------------------") - return self + pass @register(BertForTokenClassification) @@ -485,14 +479,7 @@ class PipelinedBertForTokenClassification(BertForTokenClassification, BertPipeli model = PipelinedBertForTokenClassification(config).parallelize().half() ``` """ - - def parallelize(self): - super().parallelize() - last_ipu = self.ipu_config.ipus_per_replica - 1 - logger.info(f"Classifier Output --> IPU {last_ipu}") - self.classifier = poptorch.BeginBlock(self.classifier, "Classifier Output", ipu_id=last_ipu) - logger.info("-----------------------------------------------------------") - return self + pass @register(BertForQuestionAnswering) @@ -506,13 +493,9 @@ class PipelinedBertForQuestionAnswering(BertForQuestionAnswering, BertPipelineMi ``` """ - def parallelize(self): - super().parallelize() - last_ipu = self.ipu_config.ipus_per_replica - 1 - logger.info(f"QA Outputs --> IPU {last_ipu}") - self.qa_outputs = poptorch.BeginBlock(self.qa_outputs, "QA Outputs", ipu_id=last_ipu) - logger.info("-----------------------------------------------------------") - return self + @property + def input_names(self): + return ["input_ids", "attention_mask", "token_type_ids", "start_positions", "end_positions"] def forward( self, diff --git a/optimum/graphcore/trainer.py b/optimum/graphcore/trainer.py index 6bf5f2722..ed72b6557 100644 --- a/optimum/graphcore/trainer.py +++ b/optimum/graphcore/trainer.py @@ -302,7 +302,7 @@ def __init__( self.eval_data_collator = data_collator_wrapper(self.eval_data_collator) self.model = to_pipelined(model, self.ipu_config, force=force_to_pipelined) - self.model.parallelize() + self.model = self.model.parallelize() self.original_model = model @@ -1628,9 +1628,9 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None): torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME)) else: rng_state = torch.random.get_rng_state() - self.model.deparallelize() + self.model = self.model.deparallelize() self.model.save_pretrained(output_dir, state_dict=state_dict) - self.model.parallelize() + self.model = self.model.parallelize() torch.random.set_rng_state(rng_state) if self.tokenizer is not None: From b1e23f3f66bcd100237577f94bf99446d78e1978 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Thu, 7 Jul 2022 12:20:50 +0200 Subject: [PATCH 02/33] Transformations are working --- optimum/graphcore/fx/transformations.py | 152 ++++++++++++------ optimum/graphcore/fx/utils.py | 24 ++- .../graphcore/models/bert/modeling_bert.py | 120 +++++++------- optimum/graphcore/trainer.py | 12 +- 4 files changed, 175 insertions(+), 133 deletions(-) diff --git a/optimum/graphcore/fx/transformations.py b/optimum/graphcore/fx/transformations.py index 9f4296724..2f6c988f2 100644 --- a/optimum/graphcore/fx/transformations.py +++ b/optimum/graphcore/fx/transformations.py @@ -14,7 +14,7 @@ # limitations under the License. 
import collections import re -from typing import TYPE_CHECKING, List, Optional, Union +from typing import TYPE_CHECKING, Callable, List, Optional, Union import torch @@ -22,12 +22,7 @@ from optimum.utils import logging from ...fx.optimization import ReversibleTransformation, Transformation -from ..modeling_utils import ( - SerializedEmbedding, - SerializedLinear, - get_layer_ipu, - outline_attribute, -) +from ..modeling_utils import SerializedEmbedding, SerializedLinear if TYPE_CHECKING: @@ -42,21 +37,28 @@ def node_matches_pattern(pattern, node: "Node"): return re.match(pattern, name) +def parent_module_qualified_name(node: "Node") -> str: + return getattr(node, "parent_module_qualified_name", "") + + class AddPoptorchBlockBase(ReversibleTransformation): + """ + Base class that provide useful methods for inserting poptorch blocks in the model. + """ + def __init__( - self, block_name: str, layer_ipu: Union[int, List[int]], module_name_regex: str, log_insertions: bool = False + self, block_name: str, layer_ipu: Union[int, List[int]], name_regex: str, log_insertions: bool = False ): self.block_name = block_name self.layer_ipu = layer_ipu - self.module_name_regex = re.compile(module_name_regex) if module_name_regex is not None else None + self.name_regex = re.compile(name_regex) self.log_insertions = log_insertions def find_start_nodes(self, graph_module: "GraphModule") -> List["Node"]: nodes = [] prefixes = set() for node in graph_module.graph.nodes: - # TODO: how to match the case where node.target is str - match = re.match(self.module_name_regex, node.target) if isinstance(node.target, str) else None + match = re.match(self.name_regex, parent_module_qualified_name(node)) if match: prefix = match.group(0) if prefix not in prefixes: @@ -74,30 +76,6 @@ def insert_start_block_node(self, graph_module: "GraphModule", node: "Node", blo new_node.parent_module_qualified_name = node.parent_module_qualified_name new_node.was_transformed = f"{self.__class__.__name__}" - # def start_block(inputs_to_forward, name, ipu_id): - # poptorch.Block.start(name, ipu_id=ipu_id) - # if len(inputs_to_forward) != 1: - # return inputs_to_forward - # return inputs_to_forward[0] - - # with graph_module.graph.inserting_before(node): - # new_node = graph_module.graph.call_function(start_block, (node.args, block_name, ipu_id)) - # if node.op != "get_attr": - # new_args = [] - # if len(node.args) > 1: - # for idx, _ in enumerate(node.args): - # new_args.append(graph_module.graph.call_function(operator.getitem, (new_node, idx))) - # elif node.args: - # new_args.append(new_node) - # else: - # raise NotImplementedError( - # f"Inserting start block op before a {node.op} that does not take any argument is not supported." - # ) - # node.args = tuple(new_args) - - # new_node.was_transformed = f"{self.__class__.__name__}" - # new_node.orig_node = node - def get_ipu_for_index(self, index: Optional[int] = None) -> int: if isinstance(self.layer_ipu, list): if index is None: @@ -113,6 +91,10 @@ def reverse(self, graph_module: "GraphModule") -> "GraphModule": class AddPoptorchBlocksInSeries(AddPoptorchBlockBase): + """ + Adds poptorch blocks in series, to all the layers matching name_regex. 
+ """ + def transform(self, graph_module: "GraphModule") -> "GraphModule": nodes = self.find_start_nodes(graph_module) for index, node in enumerate(nodes): @@ -125,6 +107,10 @@ def transform(self, graph_module: "GraphModule") -> "GraphModule": class AddPoptorchBlock(AddPoptorchBlockBase): + """ + Adds a poptorch block before the first node (layer) matching name_regex. + """ + def transform(self, graph_module: "GraphModule") -> "GraphModule": start_nodes = self.find_start_nodes(graph_module) if not start_nodes: @@ -142,6 +128,10 @@ class AutoParallelizeAutoEncoder(ReversibleTransformation): class TupleOutput(Transformation): + """ + Transforms the output of the model to a tuple, if it is a dict, and does not nothing otherwise. + """ + def transform(self, graph_module: "GraphModule") -> "GraphModule": for node in graph_module.graph.nodes: if node.op == "output": @@ -151,8 +141,22 @@ def transform(self, graph_module: "GraphModule") -> "GraphModule": class ClipValues(Transformation): - def __init__(self, clip_value: float): - self.clip_value = clip_value + """ + Clips values to make them fall into [min_value, max_value]. + This is useful for fp16 for instance. + """ + + def __init__( + self, + min_value: float, + max_value: float, + include_targets: Optional[List[Union[str, Callable]]] = None, + exclude_targets: Optional[List[Union[str, Callable]]] = None, + ): + self.min_value = min_value + self.max_value = max_value + self.include_targets = include_targets if include_targets is not None else [] + self.exclude_targets = exclude_targets if exclude_targets is not None else [] def _clip_node_args(self, args): if isinstance(args, (tuple, list, set)): @@ -160,19 +164,42 @@ def _clip_node_args(self, args): elif isinstance(args, dict): return {name: self._clip_node_args(arg) for name, arg in args.items()} elif isinstance(args, (float, int)): - return min(max(args, -self.clip_value), self.clip_value) + return min(max(args, self.min_value), self.max_value) else: return args def transform(self, graph_module: "GraphModule") -> "GraphModule": for node in graph_module.graph.nodes: - if node.op == "call_method" and node.target == "view": + if self.include_targets and node.target not in self.include_targets: + continue + if node.target in self.exclude_targets: continue node.args = self._clip_node_args(node.args) return graph_module +class ClipValuesSymmetric(ClipValues): + """ + Clips values to make them fall into [-clip_value, clip_value]. + This is useful for fp16 for instance. + """ + + def __init__( + self, + clip_value: float, + include_targets: Optional[List[Union[str, Callable]]] = None, + exclude_targets: Optional[List[Union[str, Callable]]] = None, + ): + if clip_value < 0: + raise ValueError(f"The provided clip value must be equal or greater than 0, but here {clip_value}.") + return super().__init__(-clip_value, clip_value, exclude_targets=exclude_targets) + + class OutlineAttribute(ReversibleTransformation): + """ + Adds an attribute to a module. This attribute will be used when comparing operation equivalence in outlining. 
+ """ + def __init__(self, name_regex: str, value: str): self.name_regex = re.compile(name_regex) self.value = value @@ -180,7 +207,6 @@ def __init__(self, name_regex: str, value: str): def transform(self, graph_module: "GraphModule") -> "GraphModule": first_match, last_match = None, None for node in graph_module.graph.nodes: - # TODO: how to match the case where node.target is str match = re.match(self.name_regex, node.target) if isinstance(node.target, str) else False if match: if first_match is None: @@ -190,7 +216,9 @@ def transform(self, graph_module: "GraphModule") -> "GraphModule": raise RuntimeError(f"Could not find any op matching {self.name_regex} to outline.") with graph_module.graph.inserting_before(first_match): - new_node = graph_module.graph.call_function(torch.ops.poptorch.set_attribute, ("__outline", "layer", self.value)) + new_node = graph_module.graph.call_function( + torch.ops.poptorch.set_attribute, ("__outline", "layer", self.value) + ) new_node.parent_module_qualified_name = first_match.parent_module_qualified_name with graph_module.graph.inserting_after(last_match): new_node = graph_module.graph.call_function(torch.ops.poptorch.clear_attribute, ("__outline", "layer")) @@ -211,6 +239,10 @@ def reverse(self, graph_module: "GraphModule") -> "GraphModule": class RecomputationCheckpoint(ReversibleTransformation): + """ + Annotates the output of a module to be checkpointed instead of recomputed. + """ + def __init__(self, name_regex: str, to_exclude: Optional[str] = None): self.name_regex = re.compile(name_regex) self.to_exclude = re.compile(to_exclude) if to_exclude is not None else None @@ -218,40 +250,46 @@ def __init__(self, name_regex: str, to_exclude: Optional[str] = None): def find_output_nodes_for_module_name(self, graph_module: "GraphModule", module_qualified_name: str): nodes_in_module = set() first_match = False - for n in graph_module.graph.nodes: - starts_with_module_qualified_name = getattr(n, "parent_module_qualified_name", "").startswith(module_qualified_name) - print(getattr(n, "parent_module_qualified_name", "")) - print(starts_with_module_qualified_name) + # Some nodes are created by calling a module that was created before in the model. This means that these nodes + # parent_module_qualified_name attributes will be "from the past", but we still want to consider them inside + # the current module, since they are called here. + modules_from_the_past = set() + for node in graph_module.graph.nodes: + name = parent_module_qualified_name(node) + starts_with_module_qualified_name = name.startswith(module_qualified_name) if not first_match and not starts_with_module_qualified_name: - continue + pass elif not first_match and starts_with_module_qualified_name: first_match = True - nodes_in_module.add(n) + nodes_in_module.add(node) elif first_match and starts_with_module_qualified_name: - nodes_in_module.add(n) + nodes_in_module.add(node) + elif first_match and name in modules_from_the_past: + # The module under which this node was created belongs to somewhere before in the hierarchy, but we + # consider this node to be part of this module since it's being used here. 
+ nodes_in_module.add(node) else: break + modules_from_the_past.add(name) nodes_in_module = {n for n in nodes_in_module if set(n.users.keys()) & nodes_in_module} return [n for n in nodes_in_module if set(n.users.keys()) - nodes_in_module] def transform(self, graph_module: "GraphModule") -> "GraphModule": matched_module_names = collections.OrderedDict() for node in graph_module.graph.nodes: - match = re.match(self.name_regex, getattr(node, "parent_module_qualified_name", "")) + match = re.match(self.name_regex, parent_module_qualified_name(node)) to_exclude = False if self.to_exclude is not None: - to_exclude = re.match(self.to_exclude, getattr(node, "parent_module_qualified_name", "")) + to_exclude = re.match(self.to_exclude, parent_module_qualified_name(node)) if match and not to_exclude: matched_module_names[match.group(0)] = None output_nodes = [] for qualified_name in matched_module_names.keys(): - print(qualified_name) output_nodes += self.find_output_nodes_for_module_name(graph_module, qualified_name) for output in output_nodes: with graph_module.graph.inserting_after(output): - print("output", output) recomputation_node = graph_module.graph.call_function(poptorch.recomputationCheckpoint) output.replace_all_uses_with(recomputation_node) recomputation_node.args = (output,) @@ -267,6 +305,12 @@ def reverse(self, graph_module: "GraphModule") -> "GraphModule": class VocabEmbeddingToSerializedEmbedding(ReversibleTransformation): + """ + Transforms the embedding layer matching name_regex to a SerializedEmbedding layer. + If no name_regex is provided, all the embeddings will be detected, but in any case, only the embedding with the + biggest number of embeddings will be transformed (this is usually the one containing the vocabulary). + """ + def __init__(self, name_regex: Optional[str] = None): self.name_regex = re.compile(name_regex) if name_regex else None @@ -307,6 +351,10 @@ def reverse(self, graph_module: "GraphModule") -> "GraphModule": class LinearToSerializedLinear(ReversibleTransformation): + """ + Transforms the linear layers matching name_regex to SerializedLinear layers. + """ + def __init__(self, name_regex: str): self.name_regex = re.compile(name_regex) if name_regex else None diff --git a/optimum/graphcore/fx/utils.py b/optimum/graphcore/fx/utils.py index 03e8e51f4..185133582 100644 --- a/optimum/graphcore/fx/utils.py +++ b/optimum/graphcore/fx/utils.py @@ -13,20 +13,25 @@ # See the License for the specific language governing permissions and # limitations under the License. import math -from typing import TYPE_CHECKING, Callable, Optional, List +from typing import Callable, List, Optional import torch import transformers -from transformers.utils.fx import HFTracer, check_if_model_is_supported, get_concrete_args, _gen_constructor_wrapper +from transformers.utils.fx import HFTracer, get_concrete_args from ..modeling_utils import PipelineMixin -if TYPE_CHECKING: - from transformers import PreTrainedModel - class PipelinedTracer(HFTracer): + """ + Tracer that enables tracing and transforming models to run them on IPUs. + Compared to the HFTracer, this one adds the following features: + - Ops can be wrapped (not only attributes of the torch module) to enable tracing. + - Each node contains the "parent_module_qualified_name" attribute, specifying under which module the node was + created. This is useful because some transformations need that, for instance RecomputationCheckpoint. 
+ """ + def __init__(self, autowrap_modules=(math,), autowrap_functions=()): super().__init__(autowrap_modules=autowrap_modules, autowrap_functions=autowrap_functions) self.ops_to_wrap = [] @@ -129,14 +134,7 @@ def symbolic_trace_pipelined_model(pipelined_model: PipelineMixin) -> PipelineMi # Trick to make HFTracer._generate_dummy_input work with the pipelined class. # This attribute will be set properly in symbolic_trace_with_pipelined_tracer once tracing is done. pipelined_model.class_for_deserialization = transformers_class - - traced = symbolic_trace_with_pipelined_tracer( - pipelined_model, input_names=pipelined_model.input_names - ) - + traced = symbolic_trace_with_pipelined_tracer(pipelined_model, input_names=pipelined_model.input_names) type_ = type(f"Traced{pipelined_model.__class__.__name__}", (torch.fx.GraphModule, pipelined_model.__class__), {}) traced.__class__ = type_ - - # traced.ipu_config = pipelined_model.ipu_config - # traced._hooks = pipelined_model._hooks return traced diff --git a/optimum/graphcore/models/bert/modeling_bert.py b/optimum/graphcore/models/bert/modeling_bert.py index ec8b39938..15a07eb64 100644 --- a/optimum/graphcore/models/bert/modeling_bert.py +++ b/optimum/graphcore/models/bert/modeling_bert.py @@ -30,14 +30,15 @@ BertForSequenceClassification, BertForTokenClassification, ) - from transformers.utils.fx import _gen_constructor_wrapper -from ....fx.optimization import ChangeTrueDivToMulByInverse, FuseBiasInLinear, MergeLinears, compose +# from ....fx.optimization import ChangeTrueDivToMulByInverse, FuseBiasInLinear, MergeLinears, compose +from ....fx.optimization import ChangeTrueDivToMulByInverse, MergeLinears, compose from ...fx.transformations import ( AddPoptorchBlock, AddPoptorchBlocksInSeries, ClipValues, + ClipValuesSymmetric, LinearToSerializedLinear, OutlineAttribute, RecomputationCheckpoint, @@ -45,20 +46,25 @@ VocabEmbeddingToSerializedEmbedding, ) from ...fx.utils import symbolic_trace_pipelined_model -from ...modeling_utils import ( - OnehotGather, - PipelineMixin, - SerializedLinear, - get_layer_ipu, - outline_attribute, - recomputation_checkpoint, - register, -) +from ...modeling_utils import OnehotGather, PipelineMixin, get_layer_ipu, register logger = logging.get_logger(__name__) +_OPTIMIZATION_TRANSFORMATIONS = [ + ChangeTrueDivToMulByInverse(), + MergeLinears(), + # FuseBiasInLinear(), +] + +_NON_REVERSIBLE_TRANSFORMATIONS = [ + ClipValuesSymmetric(1e4, exclude_targets=["view"]), + ClipValues(1e-4, float("inf"), include_targets=[torch.nn.LayerNorm]), + TupleOutput(), +] + + @register(BertForPreTraining) class PipelinedBertForPreTraining(BertForPreTraining, PipelineMixin): """ @@ -84,19 +90,22 @@ def get_transformations(self): log_insertions = self.ipu_config.log_insertions layer_ipu = get_layer_ipu(self.ipu_config.layers_per_ipu) transformations = [ - ChangeTrueDivToMulByInverse(), - MergeLinears(), - # FuseBiasInLinear(), - AddPoptorchBlock( - "Embedding", layer_ipu=0, module_name_regex="bert.embeddings", log_insertions=log_insertions - ), + AddPoptorchBlock("Embedding", 0, "bert.embeddings", log_insertions=log_insertions), OutlineAttribute("bert.embeddings.LayerNorm", "Embedding"), AddPoptorchBlocksInSeries( - "Encoder", layer_ipu, module_name_regex=r"bert.encoder.layer.[0-9]+", log_insertions=log_insertions + "Encoder", layer_ipu, r"bert.encoder.layer.[0-9]+", log_insertions=log_insertions ), AddPoptorchBlock("Pooler Output", 0, "bert.pooler", log_insertions=log_insertions), AddPoptorchBlock("Classifier Output", 0, "cls", 
log_insertions=log_insertions), ] + if self.ipu_config.recompute_checkpoint_every_layer: + transformations.append( + RecomputationCheckpoint( + "bert.encoder.layer.[0-9]+", to_exclude=f"bert.encoder.layer.{self.config.num_hidden_layers - 1}" + ) + ) + if self.ipu_config.embedding_serialization_factor > 1: + transformations.append(LinearToSerializedLinear("cls.predictions.decoder")) return transformations def parallelize(self): @@ -110,10 +119,9 @@ def parallelize(self): super().parallelize() traced = symbolic_trace_pipelined_model(self) transformations = self.get_transformations() - # if self.ipu_config.embedding_serialization_factor > 1: - # transformations.append(LinearToSerializedLinear("cls.predictions.decoder")) + transformations += _OPTIMIZATION_TRANSFORMATIONS composition = compose(*transformations) - non_reversible_composition = compose(ClipValues(1e4), TupleOutput()) + non_reversible_composition = compose(*_NON_REVERSIBLE_TRANSFORMATIONS) traced = composition(traced) traced = non_reversible_composition(traced) return traced @@ -126,8 +134,7 @@ def deparallelize(self): """ super().deparallelize() transformations = self.get_transformations() - if self.ipu_config.embedding_serialization_factor > 1: - transformations.append(LinearToSerializedLinear("cls.predictions.decoder")) + transformations += _OPTIMIZATION_TRANSFORMATIONS composition = compose(*transformations) self = composition(self, reverse=True) return self @@ -242,20 +249,21 @@ def get_transformations(self): log_insertions = self.ipu_config.log_insertions layer_ipu = get_layer_ipu(self.ipu_config.layers_per_ipu) transformations = [ - AddPoptorchBlock( - "Embedding", layer_ipu=0, module_name_regex="bert.embeddings", log_insertions=log_insertions - ), + AddPoptorchBlock("Embedding", 0, "bert.embeddings", log_insertions=log_insertions), OutlineAttribute("bert.embeddings.LayerNorm", "Embedding"), AddPoptorchBlocksInSeries( - "Encoder", layer_ipu, module_name_regex=r"bert.encoder.layer.[0-9]+", log_insertions=log_insertions + "Encoder", layer_ipu, r"bert.encoder.layer.[0-9]+", log_insertions=log_insertions ), AddPoptorchBlock("Classifier Output", 0, "cls", log_insertions=log_insertions), ] if self.ipu_config.recompute_checkpoint_every_layer: - transformations.append(RecomputationCheckpoint("bert.encoder.layer.[0-9]+", to_exclude=f"bert.encoder.layer.{self.config.num_hidden_layers - 1}")) + transformations.append( + RecomputationCheckpoint( + "bert.encoder.layer.[0-9]+", to_exclude=f"bert.encoder.layer.{self.config.num_hidden_layers - 1}" + ) + ) if self.ipu_config.embedding_serialization_factor > 1: transformations.append(LinearToSerializedLinear("cls.predictions.decoder")) - transformations += [ChangeTrueDivToMulByInverse(), MergeLinears()] return transformations def parallelize(self): @@ -269,13 +277,11 @@ def parallelize(self): super().parallelize() traced = symbolic_trace_pipelined_model(self) transformations = self.get_transformations() + transformations += _OPTIMIZATION_TRANSFORMATIONS composition = compose(*transformations) - non_reversible_composition = compose(ClipValues(1e4), TupleOutput()) - + non_reversible_composition = compose(*_NON_REVERSIBLE_TRANSFORMATIONS) traced = composition(traced) traced = non_reversible_composition(traced) - import pdb; pdb.set_trace() - return traced def deparallelize(self): @@ -285,11 +291,10 @@ def deparallelize(self): compatible with the original model. 
""" super().deparallelize() - transformations = self.get_transformations() + transformations += _OPTIMIZATION_TRANSFORMATIONS composition = compose(*transformations) self = composition(self, reverse=True) - return self def forward( @@ -372,19 +377,23 @@ def get_transformations(self): layer_ipu = get_layer_ipu(self.ipu_config.layers_per_ipu) last_ipu = len(self.ipu_config.layers_per_ipu) - 1 transformations = [ - ChangeTrueDivToMulByInverse(), - MergeLinears(), - AddPoptorchBlock( - "Embedding", layer_ipu=0, module_name_regex="bert.embeddings", log_insertions=log_insertions - ), + AddPoptorchBlock("Embedding", 0, "bert.embeddings", log_insertions=log_insertions), OutlineAttribute("bert.embeddings.LayerNorm", "Embedding"), AddPoptorchBlocksInSeries( - "Encoder", layer_ipu, module_name_regex=r"bert.encoder.layer.[0-9]+", log_insertions=log_insertions + "Encoder", layer_ipu, r"bert.encoder.layer.[0-9]+", log_insertions=log_insertions ), # Only one of the following AddPoptorchBlock, will actually add a block. AddPoptorchBlock("Classifier Output", last_ipu, "classifier", log_insertions=log_insertions), AddPoptorchBlock("QA Outputs", last_ipu, "qa_outputs", log_insertions=log_insertions), ] + if self.ipu_config.recompute_checkpoint_every_layer: + transformations.append( + RecomputationCheckpoint( + "bert.encoder.layer.[0-9]+", to_exclude=f"bert.encoder.layer.{self.config.num_hidden_layers - 1}" + ) + ) + if self.ipu_config.embedding_serialization_factor > 1: + transformations.append(VocabEmbeddingToSerializedEmbedding()) return transformations @property @@ -400,26 +409,13 @@ def parallelize(self): - Adds recomputation checkpoints """ super().parallelize() - - # if self.ipu_config.recompute_checkpoint_every_layer: - # for layer in self.bert.encoder.layer[:-1]: - # h = recomputation_checkpoint(layer) - # self._hooks.append(h) - traced = symbolic_trace_pipelined_model(self) - transformations = self.get_transformations() - - if traced.ipu_config.embedding_serialization_factor > 1: - transformations.append(VocabEmbeddingToSerializedEmbedding()) - + transformations += _OPTIMIZATION_TRANSFORMATIONS composition = compose(*transformations) - - non_reversible_composition = compose(ClipValues(1e4), TupleOutput()) - + non_reversible_composition = compose(*_NON_REVERSIBLE_TRANSFORMATIONS) traced = composition(traced) traced = non_reversible_composition(traced) - return traced def deparallelize(self): @@ -429,17 +425,10 @@ def deparallelize(self): compatible with the original model. 
""" super().deparallelize() - transformations = self.get_transformations() - if self.ipu_config.embedding_serialization_factor > 1: - transformations.append(VocabEmbeddingToSerializedEmbedding()) - - # if self.ipu_config.recompute_checkpoint_every_layer: - # transformations.append(RecomputationCheckpoint()) - + transformations += _OPTIMIZATION_TRANSFORMATIONS composition = compose(*transformations) self = composition(self, reverse=True) - return self @@ -453,6 +442,7 @@ class PipelinedBertForSequenceClassification(BertForSequenceClassification, Bert model = PipelinedBertForSequenceClassification(config).parallelize().half() ``` """ + pass @@ -466,6 +456,7 @@ class PipelinedBertForMultipleChoice(BertForMultipleChoice, BertPipelineMixin): model = PipelinedBertForMultipleChoice(config).parallelize().half() ``` """ + pass @@ -479,6 +470,7 @@ class PipelinedBertForTokenClassification(BertForTokenClassification, BertPipeli model = PipelinedBertForTokenClassification(config).parallelize().half() ``` """ + pass diff --git a/optimum/graphcore/trainer.py b/optimum/graphcore/trainer.py index ed72b6557..d4d8db506 100644 --- a/optimum/graphcore/trainer.py +++ b/optimum/graphcore/trainer.py @@ -302,12 +302,15 @@ def __init__( self.eval_data_collator = data_collator_wrapper(self.eval_data_collator) self.model = to_pipelined(model, self.ipu_config, force=force_to_pipelined) - self.model = self.model.parallelize() + parallelized_from_training = self.model.parallelize() + self.model_for_eval = self.model.eval().parallelize() + self.model = parallelized_from_training self.original_model = model if not self.args.fp32: self.model = self.model.half() + self.model_for_eval = self.model_for_eval.half() self.training_model = None self.inference_model = None @@ -1323,9 +1326,10 @@ def _load_best_model(self): ) def _load_state_dict_in_model(self, state_dict): - self.model.deparallelize() + self.model = self.model.deparallelize() load_result = self.model.load_state_dict(state_dict, strict=False) - self.model.parallelize() + self.model = self.model.parallelize() + if not self.args.fp32: self.model.half() @@ -1842,7 +1846,7 @@ def predict( return PredictionOutput(predictions=output.predictions, label_ids=output.label_ids, metrics=output.metrics) def _wrap_and_compile_model_for_evaluation(self, dataloader, prediction_loss_only): - model = self.wrap_model(self.model, training=False) + model = self.wrap_model(self.model_for_eval, training=False) self.compile_model(model, next(iter(dataloader)), log=True) return model From 87e9273fc2e44a7bd78ba1eee8ad09a1d08686be Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Fri, 8 Jul 2022 15:38:59 +0200 Subject: [PATCH 03/33] [WIP] fx formalism for BERT --- examples/language-modeling/ipu_config.json | 2 +- optimum/graphcore/fx/transformations.py | 29 ++++++++++++++++++- .../graphcore/models/bert/modeling_bert.py | 13 +++++++-- 3 files changed, 39 insertions(+), 5 deletions(-) diff --git a/examples/language-modeling/ipu_config.json b/examples/language-modeling/ipu_config.json index 3f9beb85b..43b9cc44e 100644 --- a/examples/language-modeling/ipu_config.json +++ b/examples/language-modeling/ipu_config.json @@ -1,5 +1,5 @@ { - "embedding_serialization_factor": 1, + "embedding_serialization_factor": 2, "recompute_checkpoint_every_layer": true, "optimizer_state_offchip": true, "replicated_tensor_sharding": true, diff --git a/optimum/graphcore/fx/transformations.py b/optimum/graphcore/fx/transformations.py index 2f6c988f2..1f80acbb5 100644 --- a/optimum/graphcore/fx/transformations.py 
+++ b/optimum/graphcore/fx/transformations.py @@ -375,7 +375,6 @@ def transform(self, graph_module: "GraphModule") -> "GraphModule": serialized_linear.load_state_dict(linear.state_dict()) parent_fully_qualified_name, linear_name = node.target.rsplit(".", maxsplit=1) setattr(graph_module.get_submodule(parent_fully_qualified_name), linear_name, serialized_linear) - graph_module.tie_weights() return graph_module def reverse(self, graph_module: "GraphModule") -> "GraphModule": @@ -383,3 +382,31 @@ def reverse(self, graph_module: "GraphModule") -> "GraphModule": if node.op == "call_module" and isinstance(graph_module.get_submodule(node.target), SerializedLinear): graph_module.get_submodule(node.target).__class__ = torch.nn.Linear return graph_module + + +class TieWeights(Transformation): + def __init__(self, layer_a: str, layer_b: str, weight_attribute_name_for_a: Optional[str] = "weight", weight_attribute_name_for_b: Optional[str] = "weight"): + self.layer_a = layer_a + self.layer_b = layer_b + self.layer_b = layer_b + self.weight_attribute_name_for_a = weight_attribute_name_for_a + self.weight_attribute_name_for_b = weight_attribute_name_for_b + + def transform(self, graph_module: "GraphModule") -> "GraphModule": + layer_a, layer_b = None, None + for node in graph_module.graph.nodes: + if node.op == "call_module": + if node.target == self.layer_a: + layer_a = graph_module.get_submodule(node.target) + if node.target == self.layer_b: + layer_b = graph_module.get_submodule(node.target) + + if layer_a is None or layer_b is None: + raise ValueError(f"Could not find both layers {self.layer_a} and {self.layer_b} to tie their weights together") + if not hasattr(layer_a, self.weight_attribute_name_for_a): + raise AttributeError(f"{layer_a} does not have an attribute called {self.weight_attribute_name_for_a}") + if not hasattr(layer_b, self.weight_attribute_name_for_b): + raise AttributeError(f"{layer_b} does not have an attribute called {self.weight_attribute_name_for_b}") + + setattr(layer_b, self.weight_attribute_name_for_b, getattr(layer_a, self.weight_attribute_name_for_a)) + return graph_module diff --git a/optimum/graphcore/models/bert/modeling_bert.py b/optimum/graphcore/models/bert/modeling_bert.py index 15a07eb64..c1c6ba210 100644 --- a/optimum/graphcore/models/bert/modeling_bert.py +++ b/optimum/graphcore/models/bert/modeling_bert.py @@ -42,6 +42,7 @@ LinearToSerializedLinear, OutlineAttribute, RecomputationCheckpoint, + TieWeights, TupleOutput, VocabEmbeddingToSerializedEmbedding, ) @@ -105,7 +106,10 @@ def get_transformations(self): ) ) if self.ipu_config.embedding_serialization_factor > 1: - transformations.append(LinearToSerializedLinear("cls.predictions.decoder")) + transformations += [ + LinearToSerializedLinear("cls.predictions.decoder"), + TieWeights("bert.embeddings.word_embeddings", "cls.predictions.decoder"), + ] return transformations def parallelize(self): @@ -263,7 +267,10 @@ def get_transformations(self): ) ) if self.ipu_config.embedding_serialization_factor > 1: - transformations.append(LinearToSerializedLinear("cls.predictions.decoder")) + transformations += [ + LinearToSerializedLinear("cls.predictions.decoder"), + TieWeights("bert.embeddings.word_embeddings", "cls.predictions.decoder"), + ] return transformations def parallelize(self): @@ -411,7 +418,7 @@ def parallelize(self): super().parallelize() traced = symbolic_trace_pipelined_model(self) transformations = self.get_transformations() - transformations += _OPTIMIZATION_TRANSFORMATIONS + # transformations += 
_OPTIMIZATION_TRANSFORMATIONS composition = compose(*transformations) non_reversible_composition = compose(*_NON_REVERSIBLE_TRANSFORMATIONS) traced = composition(traced) From 7009e6aa31367df7d05ebc1a27131f9d869bb7be Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Thu, 21 Jul 2022 17:17:22 +0200 Subject: [PATCH 04/33] Temp --- optimum/graphcore/fx/transformations.py | 2 +- optimum/graphcore/models/bert/modeling_bert.py | 2 +- optimum/graphcore/trainer.py | 4 +++- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/optimum/graphcore/fx/transformations.py b/optimum/graphcore/fx/transformations.py index 1f80acbb5..8d5528526 100644 --- a/optimum/graphcore/fx/transformations.py +++ b/optimum/graphcore/fx/transformations.py @@ -271,7 +271,7 @@ def find_output_nodes_for_module_name(self, graph_module: "GraphModule", module_ else: break modules_from_the_past.add(name) - nodes_in_module = {n for n in nodes_in_module if set(n.users.keys()) & nodes_in_module} + # nodes_in_module = {n for n in nodes_in_module if set(n.users.keys()) & nodes_in_module} return [n for n in nodes_in_module if set(n.users.keys()) - nodes_in_module] def transform(self, graph_module: "GraphModule") -> "GraphModule": diff --git a/optimum/graphcore/models/bert/modeling_bert.py b/optimum/graphcore/models/bert/modeling_bert.py index c1c6ba210..d76346335 100644 --- a/optimum/graphcore/models/bert/modeling_bert.py +++ b/optimum/graphcore/models/bert/modeling_bert.py @@ -418,7 +418,7 @@ def parallelize(self): super().parallelize() traced = symbolic_trace_pipelined_model(self) transformations = self.get_transformations() - # transformations += _OPTIMIZATION_TRANSFORMATIONS + transformations += _OPTIMIZATION_TRANSFORMATIONS composition = compose(*transformations) non_reversible_composition = compose(*_NON_REVERSIBLE_TRANSFORMATIONS) traced = composition(traced) diff --git a/optimum/graphcore/trainer.py b/optimum/graphcore/trainer.py index d4d8db506..ea0b049aa 100644 --- a/optimum/graphcore/trainer.py +++ b/optimum/graphcore/trainer.py @@ -310,7 +310,9 @@ def __init__( if not self.args.fp32: self.model = self.model.half() - self.model_for_eval = self.model_for_eval.half() + # inputs = {k: torch.ones(2, 23, dtype=torch.int64) for k in ["input_ids", "attention_mask", "token_type_ids"]} + # self.model_for_eval = self.model_for_eval.half() + import ipdb; ipdb.set_trace() self.training_model = None self.inference_model = None From d7df7b501962fd82e7251779f0068de3d5db2805 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Wed, 3 Aug 2022 15:36:58 +0200 Subject: [PATCH 05/33] Trainer with symbolically traced models --- optimum/graphcore/fx/transformations.py | 12 ++- optimum/graphcore/fx/utils.py | 13 ++- optimum/graphcore/modeling_utils.py | 14 ++- .../graphcore/models/bert/modeling_bert.py | 4 - optimum/graphcore/models/t5/modeling_t5.py | 2 +- optimum/graphcore/trainer.py | 86 ++++++++++++------- optimum/graphcore/trainer_seq2seq.py | 6 -- 7 files changed, 87 insertions(+), 50 deletions(-) diff --git a/optimum/graphcore/fx/transformations.py b/optimum/graphcore/fx/transformations.py index 8d5528526..9096f3bc8 100644 --- a/optimum/graphcore/fx/transformations.py +++ b/optimum/graphcore/fx/transformations.py @@ -385,7 +385,13 @@ def reverse(self, graph_module: "GraphModule") -> "GraphModule": class TieWeights(Transformation): - def __init__(self, layer_a: str, layer_b: str, weight_attribute_name_for_a: Optional[str] = "weight", weight_attribute_name_for_b: Optional[str] = "weight"): + def __init__( + self, + layer_a: str, + 
layer_b: str, + weight_attribute_name_for_a: Optional[str] = "weight", + weight_attribute_name_for_b: Optional[str] = "weight", + ): self.layer_a = layer_a self.layer_b = layer_b self.layer_b = layer_b @@ -402,7 +408,9 @@ def transform(self, graph_module: "GraphModule") -> "GraphModule": layer_b = graph_module.get_submodule(node.target) if layer_a is None or layer_b is None: - raise ValueError(f"Could not find both layers {self.layer_a} and {self.layer_b} to tie their weights together") + raise ValueError( + f"Could not find both layers {self.layer_a} and {self.layer_b} to tie their weights together" + ) if not hasattr(layer_a, self.weight_attribute_name_for_a): raise AttributeError(f"{layer_a} does not have an attribute called {self.weight_attribute_name_for_a}") if not hasattr(layer_b, self.weight_attribute_name_for_b): diff --git a/optimum/graphcore/fx/utils.py b/optimum/graphcore/fx/utils.py index 185133582..c149e8998 100644 --- a/optimum/graphcore/fx/utils.py +++ b/optimum/graphcore/fx/utils.py @@ -126,15 +126,22 @@ def symbolic_trace_pipelined_model(pipelined_model: PipelineMixin) -> PipelineMi return pipelined_model transformers_class = None - for base in pipelined_model.__class__.__bases__: - if transformers.PreTrainedModel in base.__mro__: + bases = list(pipelined_model.__class__.__bases__) + import inspect + + while bases: + base = bases.pop(0) + if inspect.getmodule(base).__name__.startswith("transformers") and transformers.PreTrainedModel in base.mro(): transformers_class = base break + bases += list(base.__bases__) # Trick to make HFTracer._generate_dummy_input work with the pipelined class. # This attribute will be set properly in symbolic_trace_with_pipelined_tracer once tracing is done. pipelined_model.class_for_deserialization = transformers_class - traced = symbolic_trace_with_pipelined_tracer(pipelined_model, input_names=pipelined_model.input_names) + traced = symbolic_trace_with_pipelined_tracer( + pipelined_model, input_names=pipelined_model.input_names_for_symbolic_trace + ) type_ = type(f"Traced{pipelined_model.__class__.__name__}", (torch.fx.GraphModule, pipelined_model.__class__), {}) traced.__class__ = type_ return traced diff --git a/optimum/graphcore/modeling_utils.py b/optimum/graphcore/modeling_utils.py index c16e3b2f1..e749eae2b 100644 --- a/optimum/graphcore/modeling_utils.py +++ b/optimum/graphcore/modeling_utils.py @@ -148,8 +148,18 @@ def get_transformations(self): raise NotImplementedError("You need to implement get_transformations.") @property - def input_names(self): - return list(inspect.signature(self.forward).parameters.keys()) + def input_names_for_symbolic_trace(self): + # input_names_attribute = "_input_names_for_symbolic_trace" if self.training else "_eval_input_names_for_symbolic_trace" + input_names_attribute = "_input_names_for_symbolic_trace" + if not hasattr(self, input_names_attribute): + setattr(self, input_names_attribute, list(inspect.signature(self.forward).parameters.keys())) + return getattr(self, input_names_attribute) + + @input_names_for_symbolic_trace.setter + def input_names_for_symbolic_trace(self, input_names: List[str]): + # input_names_attribute = "_input_names_for_symbolic_trace" if self.training else "_eval_input_names_for_symbolic_trace" + input_names_attribute = "_input_names_for_symbolic_trace" + setattr(self, input_names_attribute, input_names) def parallelize(self): """Transforms the model to run in an IPU pipeline.""" diff --git a/optimum/graphcore/models/bert/modeling_bert.py 
b/optimum/graphcore/models/bert/modeling_bert.py index d76346335..4c7167885 100644 --- a/optimum/graphcore/models/bert/modeling_bert.py +++ b/optimum/graphcore/models/bert/modeling_bert.py @@ -403,10 +403,6 @@ def get_transformations(self): transformations.append(VocabEmbeddingToSerializedEmbedding()) return transformations - @property - def input_names(self): - return ["input_ids", "attention_mask", "token_type_ids", "labels"] - def parallelize(self): """ Transform the model to run in an IPU pipeline. diff --git a/optimum/graphcore/models/t5/modeling_t5.py b/optimum/graphcore/models/t5/modeling_t5.py index 8a93dc476..eac8adc13 100644 --- a/optimum/graphcore/models/t5/modeling_t5.py +++ b/optimum/graphcore/models/t5/modeling_t5.py @@ -425,7 +425,7 @@ def forward( if self.config.tie_word_embeddings: # Rescale output before projecting on vocab # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 - sequence_output = sequence_output * (self.model_dim**-0.5) + sequence_output = sequence_output * (self.model_dim ** -0.5) lm_scale_modifier = getattr(self, "lm_scale_modifier", None) if lm_scale_modifier is not None: diff --git a/optimum/graphcore/trainer.py b/optimum/graphcore/trainer.py index ea0b049aa..63ac242dd 100644 --- a/optimum/graphcore/trainer.py +++ b/optimum/graphcore/trainer.py @@ -302,17 +302,12 @@ def __init__( self.eval_data_collator = data_collator_wrapper(self.eval_data_collator) self.model = to_pipelined(model, self.ipu_config, force=force_to_pipelined) - parallelized_from_training = self.model.parallelize() - self.model_for_eval = self.model.eval().parallelize() - self.model = parallelized_from_training + self.model_for_eval = self.model.eval() self.original_model = model if not self.args.fp32: self.model = self.model.half() - # inputs = {k: torch.ones(2, 23, dtype=torch.int64) for k in ["input_ids", "attention_mask", "token_type_ids"]} - # self.model_for_eval = self.model_for_eval.half() - import ipdb; ipdb.set_trace() self.training_model = None self.inference_model = None @@ -453,35 +448,52 @@ def pytorch_optimizer_to_poptorch( def compile_model( self, - model: poptorch.PoplarExecutor, sample_batch: Union[Dict[str, torch.Tensor], Tuple[torch.Tensor]], + training: bool, log: bool = False, ): """ Compiles the model with PopTorch. Args: - model (`poptorch.PoplarExecutor`): - The model to compile (already wrapped). sample_batch (`Dict[str, torch.Tensor]` or `Tuple[torch.Tensor]`): The inputs to use the compilation, this will set the input shapes that the compiled model can accept. + training (`bool`): + Whether to compile the model for training or not. log (`bool`, *optional*, defaults to `False`): Whether to log that compilation is happening or not. + + Returns: + `poptorch.PoplarExecutor`: The compiled model. """ # Skipping compilation if the model was already compiled. 
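A minimal sketch of how the reworked `compile_model` entry point might be driven, assuming a trainer instance and dataloaders like the ones constructed elsewhere in this series; `trainer`, `train_dataloader` and `eval_dataloader` below are placeholders, not names introduced by this patch:
```
# Hypothetical call sites for the new signature (sample_batch, training, log):
train_executor = trainer.compile_model(next(iter(train_dataloader)), training=True, log=True)
eval_executor = trainer.compile_model(next(iter(eval_dataloader)), training=False)
# Per the caching check at the top of the body, a second call with training=True
# should hand back the already-compiled poptorch.PoplarExecutor.
train_executor_again = trainer.compile_model(next(iter(train_dataloader)), training=True)
```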
- if model.isCompiled(): - return - if log: - logger.info("Compiling Model...") - sample_batch = self._prepare_inputs(sample_batch) - start_compile = time.perf_counter() - if isinstance(sample_batch, tuple): - model.compile(*sample_batch) + if training and self.training_model is not None: + return self.training_model + elif not training and self.inference_model is not None: + return self.inference_model else: - model.compile(**sample_batch) - duration_compilation = time.perf_counter() - start_compile - if log: - logger.info(f"Compiled/Loaded model in {duration_compilation} secs") + sample_batch = self._prepare_inputs(sample_batch) + model = self.model if training else self.model_for_eval + model.input_names_for_symbolic_trace = list(sample_batch.keys()) + model = model.parallelize() + if not self.args.fp32: + model.half() + if training: + self.model = model + else: + self.model_for_eval = model + model = self._wrap_model(model, training=training) + if log: + logger.info("Compiling Model...") + start_compile = time.perf_counter() + if isinstance(sample_batch, tuple): + model.compile(*sample_batch) + else: + model.compile(**sample_batch) + duration_compilation = time.perf_counter() - start_compile + if log: + logger.info(f"Compiled/Loaded model in {duration_compilation} secs") + return model def add_callback(self, callback): """ @@ -791,6 +803,10 @@ def create_optimizer(self): Trainer's init through `optimizers`, or subclass and override this method in a subclass. """ if self.optimizer is None: + if not isinstance(self.model, torch.fx.GraphModule): + warnings.warn( + "The model seems to not have been parallelized, this might lead to unsuspected behaviour and/or failure with the optimizer." + ) decay_parameters = get_parameter_names(self.model, [nn.LayerNorm]) decay_parameters = {name for name in decay_parameters if "bias" not in name} if self.args.lamb or self.args.lamb_no_bias_correction: @@ -884,25 +900,27 @@ def num_examples(self, dataloader: poptorch.DataLoader) -> int: """ return len(dataloader.dataset) - def wrap_model(self, model: Union[PreTrainedModel, PoplarExecutor], training=True) -> PoplarExecutor: +def wrap_model(self, model: Union[PreTrainedModel, PoplarExecutor], training: bool =True) -> PoplarExecutor: """ Wraps a model for PopTorch, either for training or for inference. Args: - model ([`transformers.PreTrainedModel`] or `poptorch.PoplarExecutor`): + model ([`~transformers.modeling_utils.PreTrainedModel`] or `poptorch.PoplarExecutor`): The model to wrap. training (`bool`, *optional*, defaults to `True`): Whether to wrap the model for training or not. Returns: `poptorch.PoplarExecutor`: The wrapped model. - """ wrapped = None if isinstance(model, PoplarExecutor): wrapped = model elif training: if self.training_model is None: + # Creating the optimizer if it was not already created. This is needed because the optimizer model + # parameters must be exactly the same as poptorch.trainingModel parameters. 
+ self.create_optimizer() self.training_model = poptorch.trainingModel( model.train(), options=self.opts, optimizer=self.optimizer ) @@ -1064,6 +1082,12 @@ def _inner_training_loop( if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug: debug_overflow = DebugUnderflowOverflow(self.model) # noqa + # Activate gradient checkpointing if needed + if args.gradient_checkpointing: + self.model.gradient_checkpointing_enable() + + model = self._compile_model(next(iter(train_dataloader)), training=True, log=True) + self.create_optimizer_and_scheduler(num_training_steps=max_steps) self.state = IPUTrainerState() @@ -1072,6 +1096,7 @@ def _inner_training_loop( trial = None self.state.is_hyper_param_search = trial is not None + # TODO: brought by sdk3.0 pr self.training_model = self.wrap_model(self.model) # TODO: handle optimizer and scheduler creation @@ -1081,6 +1106,7 @@ def _inner_training_loop( # Check if saved optimizer or scheduler states exist self._load_optimizer_and_scheduler(resume_from_checkpoint) + # TODO: brought by sdk3.0 pr self.compile_model(self.training_model, next(iter(train_dataloader)), log=True) # Train! @@ -1328,7 +1354,8 @@ def _load_best_model(self): ) def _load_state_dict_in_model(self, state_dict): - self.model = self.model.deparallelize() + if isinstance(self.model, torch.fx.GraphModule): + self.model = self.model.deparallelize() load_result = self.model.load_state_dict(state_dict, strict=False) self.model = self.model.parallelize() @@ -1743,7 +1770,7 @@ def evaluate( # Running this here (even though it is being recalled in self.evaluation_loop to make compilation happen here. # That way, compilation will not mess inference speed metrics. - _ = self._wrap_and_compile_model_for_evaluation(eval_dataloader, prediction_loss_only) + _ = self._compile(self.model_for_eval, eval_dataloader, training=False) start_time = time.time() @@ -1847,11 +1874,6 @@ def predict( return PredictionOutput(predictions=output.predictions, label_ids=output.label_ids, metrics=output.metrics) - def _wrap_and_compile_model_for_evaluation(self, dataloader, prediction_loss_only): - model = self.wrap_model(self.model_for_eval, training=False) - self.compile_model(model, next(iter(dataloader)), log=True) - return model - def evaluation_loop( self, dataloader: poptorch.DataLoader, @@ -1869,7 +1891,7 @@ def evaluation_loop( prediction_loss_only if prediction_loss_only is not None else self.args.prediction_loss_only ) - self.inference_model = self._wrap_and_compile_model_for_evaluation(dataloader, prediction_loss_only) + model = self._compile_model(next(iter(dataloader)), training=False) batch_size = dataloader.batch_size diff --git a/optimum/graphcore/trainer_seq2seq.py b/optimum/graphcore/trainer_seq2seq.py index 14faef19f..daf2dca2b 100644 --- a/optimum/graphcore/trainer_seq2seq.py +++ b/optimum/graphcore/trainer_seq2seq.py @@ -28,12 +28,6 @@ class IPUSeq2SeqTrainer(IPUTrainer): - def _wrap_and_compile_model_for_evaluation(self, dataloader, prediction_loss_only): - if prediction_loss_only: - return super()._wrap_and_compile_model_for_evaluation(dataloader, prediction_loss_only) - self.model.compile_for_generate(next(iter(dataloader)), self.args.generation_num_beams) - return self.model - def evaluate( self, eval_dataset: Optional[Dataset] = None, From 29e48ecbc4909debcfb740b3628e3cc6cba3589b Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Wed, 3 Aug 2022 17:31:41 +0200 Subject: [PATCH 06/33] Fix issues --- optimum/graphcore/models/t5/modeling_t5.py | 2 +- optimum/graphcore/trainer.py | 8 ++++++-- 2 files 
changed, 7 insertions(+), 3 deletions(-) diff --git a/optimum/graphcore/models/t5/modeling_t5.py b/optimum/graphcore/models/t5/modeling_t5.py index eac8adc13..8a93dc476 100644 --- a/optimum/graphcore/models/t5/modeling_t5.py +++ b/optimum/graphcore/models/t5/modeling_t5.py @@ -425,7 +425,7 @@ def forward( if self.config.tie_word_embeddings: # Rescale output before projecting on vocab # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 - sequence_output = sequence_output * (self.model_dim ** -0.5) + sequence_output = sequence_output * (self.model_dim**-0.5) lm_scale_modifier = getattr(self, "lm_scale_modifier", None) if lm_scale_modifier is not None: diff --git a/optimum/graphcore/trainer.py b/optimum/graphcore/trainer.py index 63ac242dd..b5399235e 100644 --- a/optimum/graphcore/trainer.py +++ b/optimum/graphcore/trainer.py @@ -474,7 +474,11 @@ def compile_model( else: sample_batch = self._prepare_inputs(sample_batch) model = self.model if training else self.model_for_eval - model.input_names_for_symbolic_trace = list(sample_batch.keys()) + if isinstance(sample_batch, tuple): + signature = inspect.signature(model.forward) + model.input_names_for_symbolic_trace = list(signature.parameters.keys())[: len(sample_batch)] + else: + model.input_names_for_symbolic_trace = list(sample_batch.keys()) model = model.parallelize() if not self.args.fp32: model.half() @@ -1770,7 +1774,7 @@ def evaluate( # Running this here (even though it is being recalled in self.evaluation_loop to make compilation happen here. # That way, compilation will not mess inference speed metrics. - _ = self._compile(self.model_for_eval, eval_dataloader, training=False) + _ = self._compile_model(next(iter(eval_dataloader)), training=False) start_time = time.time() From 1ac1e9f1627b737a176d79cf648c83282a284fb8 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Fri, 5 Aug 2022 18:43:45 +0200 Subject: [PATCH 07/33] FX parallelize for T5 --- optimum/graphcore/fx/transformations.py | 105 ++++++- .../graphcore/models/bert/modeling_bert.py | 1 - optimum/graphcore/models/t5/modeling_t5.py | 294 ++++++++---------- 3 files changed, 219 insertions(+), 181 deletions(-) diff --git a/optimum/graphcore/fx/transformations.py b/optimum/graphcore/fx/transformations.py index 9096f3bc8..de8710e68 100644 --- a/optimum/graphcore/fx/transformations.py +++ b/optimum/graphcore/fx/transformations.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
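For reference, a runnable toy of how `input_names_for_symbolic_trace` is derived in the trainer hunk above, depending on whether the sample batch is a dict or a tuple; the `forward` signature here is a stand-in, not a real pipelined model:
```
import inspect

def forward(input_ids, attention_mask, token_type_ids=None, labels=None):
    ...

# dict batch: the names come straight from the keys
sample_dict = {"input_ids": None, "attention_mask": None, "labels": None}
print(list(sample_dict.keys()))  # ['input_ids', 'attention_mask', 'labels']

# tuple batch: the names come from the first len(batch) parameters of forward
sample_tuple = (None, None)
signature = inspect.signature(forward)
print(list(signature.parameters.keys())[: len(sample_tuple)])  # ['input_ids', 'attention_mask']
```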
import collections +import operator import re from typing import TYPE_CHECKING, Callable, List, Optional, Union @@ -22,7 +23,7 @@ from optimum.utils import logging from ...fx.optimization import ReversibleTransformation, Transformation -from ..modeling_utils import SerializedEmbedding, SerializedLinear +from ..modeling_utils import SerializedEmbedding, SerializedLinear, SharedEmbedding if TYPE_CHECKING: @@ -38,7 +39,8 @@ def node_matches_pattern(pattern, node: "Node"): def parent_module_qualified_name(node: "Node") -> str: - return getattr(node, "parent_module_qualified_name", "") + name = getattr(node, "parent_module_qualified_name", "") + return name if name != "root" else "" class AddPoptorchBlockBase(ReversibleTransformation): @@ -58,7 +60,11 @@ def find_start_nodes(self, graph_module: "GraphModule") -> List["Node"]: nodes = [] prefixes = set() for node in graph_module.graph.nodes: - match = re.match(self.name_regex, parent_module_qualified_name(node)) + # If module under which the node was created is root, we use the node name to match. + name = node.name + if parent_module_qualified_name(node) != "": + name = f"{parent_module_qualified_name(node)}.{name}" + match = re.match(self.name_regex, name) if match: prefix = match.group(0) if prefix not in prefixes: @@ -328,11 +334,18 @@ def transform(self, graph_module: "GraphModule") -> "GraphModule": raise RuntimeError("Could not find any embedding node") embedding_node = max(embedding_nodes, key=lambda node: graph_module.get_submodule(node.target).num_embeddings) - parent_fully_qualified_name, embedding_name = embedding_node.target.rsplit(".", maxsplit=1) + split = embedding_node.target.rsplit(".", maxsplit=1) + if len(split) == 1: + split = [None] + split + parent_fully_qualified_name, embedding_name = split + new_embedding = SerializedEmbedding( graph_module.get_submodule(embedding_node.target), graph_module.ipu_config.embedding_serialization_factor ) - setattr(graph_module.get_submodule(parent_fully_qualified_name), embedding_name, new_embedding) + submodule = graph_module + if parent_fully_qualified_name is not None: + submodule = graph_module.get_submodule(parent_fully_qualified_name) + setattr(submodule, embedding_name, new_embedding) embedding_node.was_transformed = "VocabEmbeddingToSerializedEmbedding" return graph_module @@ -340,9 +353,15 @@ def transform(self, graph_module: "GraphModule") -> "GraphModule": def reverse(self, graph_module: "GraphModule") -> "GraphModule": for node in graph_module.graph.nodes: if getattr(node, "was_transformed", "") == "VocabEmbeddingToSerializedEmbedding": - parent_fully_qualified_name, embedding_name = node.target.rsplit(".", maxsplit=1) + split = node.target.rsplit(".", maxsplit=1) + if len(split) == 1: + split = [None] + split + parent_fully_qualified_name, embedding_name = split + submodule = graph_module + if parent_fully_qualified_name is not None: + submodule = graph_module.get_submodule(parent_fully_qualified_name) setattr( - graph_module.get_submodule(parent_fully_qualified_name), + submodule, embedding_name, graph_module.get_submodule(node.target).deserialize(), ) @@ -373,8 +392,14 @@ def transform(self, graph_module: "GraphModule") -> "GraphModule": mode=poptorch.MatMulSerializationMode.OutputChannels, ) serialized_linear.load_state_dict(linear.state_dict()) - parent_fully_qualified_name, linear_name = node.target.rsplit(".", maxsplit=1) - setattr(graph_module.get_submodule(parent_fully_qualified_name), linear_name, serialized_linear) + split = node.target.rsplit(".", maxsplit=1) + if 
len(split) == 1: + split = [None] + split + parent_fully_qualified_name, linear_name = split + submodule = graph_module + if parent_fully_qualified_name is not None: + submodule = graph_module.get_submodule(parent_fully_qualified_name) + setattr(submodule, linear_name, serialized_linear) return graph_module def reverse(self, graph_module: "GraphModule") -> "GraphModule": @@ -418,3 +443,65 @@ def transform(self, graph_module: "GraphModule") -> "GraphModule": setattr(layer_b, self.weight_attribute_name_for_b, getattr(layer_a, self.weight_attribute_name_for_a)) return graph_module + + +class ShareEmbeddingComputation(Transformation): + def _find_nodes_to_move(self, graph_module, embedding_input_node): + to_visit = [embedding_input_node] + to_move = set() + while to_visit: + node = to_visit.pop(0) + if node.op != "placeholder": + to_move.add(node) + to_visit += node.all_input_nodes + ordered_to_move = [] + for node in graph_module.graph.nodes: + if node in to_move: + ordered_to_move.append(node) + return ordered_to_move + + def _move_nodes_after_node(self, graph_module, nodes_to_move, node): + old_to_new_mapping = {} + with graph_module.graph.inserting_after(node): + for n in reversed(nodes_to_move): + old_to_new_mapping[n] = graph_module.graph.create_node(n.op, n.target, n.args, n.kwargs, n.name) + return old_to_new_mapping + + def transform(self, graph_module: "GraphModule") -> "GraphModule": + candidates = collections.defaultdict(list) + embedding_nodes = collections.defaultdict(list) + for node in graph_module.graph.nodes: + if node.op == "call_module" and isinstance(graph_module.get_submodule(node.target), torch.nn.Embedding): + candidates[node.target].append(node.args[0]) + embedding_nodes[node.target].append(node) + + candidates = {k: v for k, v in candidates.items() if len(v) > 1} + embedding_nodes = {k: v for k, v in embedding_nodes.items() if k in candidates} + + for target, embedding_input_nodes in candidates.items(): + if len(embedding_input_nodes) > 2: + raise NotImplementedError("Currently support embedding computation sharing for 2.") + new_input_nodes = [] + for input_node in reversed(embedding_input_nodes[1:]): + nodes_to_move = self._find_nodes_to_move(graph_module, input_node) + old_to_new_mapping = self._move_nodes_after_node(graph_module, nodes_to_move, embedding_input_nodes[0]) + for old_node, new_node in old_to_new_mapping.items(): + old_node.replace_all_uses_with(new_node) + graph_module.graph.erase_node(old_node) + new_input_nodes.append(old_to_new_mapping[nodes_to_move[-1]]) + + graph_module.add_submodule(target, SharedEmbedding(graph_module.get_submodule(target))) + shared_node = embedding_nodes[target][0] + shared_node.args = tuple(embedding_input_nodes[0:1] + new_input_nodes) + with graph_module.graph.inserting_after(shared_node): + getitem = graph_module.graph.call_function(operator.getitem, ()) + shared_node.replace_all_uses_with(getitem) + getitem.args = (shared_node, 0) + with graph_module.graph.inserting_after(getitem): + for idx in reversed(range(len(embedding_nodes[target][1:]))): + embedding_node = embedding_nodes[target][idx + 1] + getitem = graph_module.graph.call_function(operator.getitem, (shared_node, idx + 1)) + embedding_node.replace_all_uses_with(getitem) + graph_module.graph.erase_node(embedding_node) + + return graph_module diff --git a/optimum/graphcore/models/bert/modeling_bert.py b/optimum/graphcore/models/bert/modeling_bert.py index 4c7167885..ebb0d89e5 100644 --- a/optimum/graphcore/models/bert/modeling_bert.py +++ 
b/optimum/graphcore/models/bert/modeling_bert.py @@ -32,7 +32,6 @@ ) from transformers.utils.fx import _gen_constructor_wrapper -# from ....fx.optimization import ChangeTrueDivToMulByInverse, FuseBiasInLinear, MergeLinears, compose from ....fx.optimization import ChangeTrueDivToMulByInverse, MergeLinears, compose from ...fx.transformations import ( AddPoptorchBlock, diff --git a/optimum/graphcore/models/t5/modeling_t5.py b/optimum/graphcore/models/t5/modeling_t5.py index 8a93dc476..71daf9e11 100644 --- a/optimum/graphcore/models/t5/modeling_t5.py +++ b/optimum/graphcore/models/t5/modeling_t5.py @@ -17,7 +17,6 @@ import torch import torch.nn as nn -from torch import Tensor import poptorch from optimum.utils import logging @@ -25,9 +24,25 @@ from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput from transformers.models.t5.modeling_t5 import __HEAD_MASK_WARNING_MSG, T5Block, T5Stack +from ....fx.optimization import ChangeTrueDivToMulByInverse, MergeLinears, compose +from ...fx.transformations import ( + AddPoptorchBlock, + AddPoptorchBlocksInSeries, + ClipValues, + ClipValuesSymmetric, + LinearToSerializedLinear, + OutlineAttribute, + RecomputationCheckpoint, + ShareEmbeddingComputation, + TieWeights, + TupleOutput, + VocabEmbeddingToSerializedEmbedding, +) +from ...fx.utils import symbolic_trace_pipelined_model from ...generation_utils import IPUGenerationMixin from ...modeling_utils import ( GenerationMethodsMixin, + OnehotGather, PipelineMixin, SerializedLinear, SharedEmbedding, @@ -39,129 +54,17 @@ logger = logging.get_logger(__name__) +_OPTIMIZATION_TRANSFORMATIONS = [ + ChangeTrueDivToMulByInverse(), + MergeLinears(), + # FuseBiasInLinear(), +] -class CustomT5Block(T5Block): - def forward( - self, - hidden_states, - attention_mask=None, - position_bias=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - encoder_decoder_position_bias=None, - layer_head_mask=None, - cross_attn_layer_head_mask=None, - past_key_value=None, - use_cache=False, - output_attentions=False, - return_dict=True, - ): - - if past_key_value is not None: - if not self.is_decoder: - logger.warning("`past_key_values` is passed to the encoder. Please make sure this is intended.") - expected_num_past_key_values = 2 if encoder_hidden_states is None else 4 - - if len(past_key_value) != expected_num_past_key_values: - raise ValueError( - f"There should be {expected_num_past_key_values} past states. " - f"{'2 (past / key) for cross attention. 
' if expected_num_past_key_values == 4 else ''}" - f"Got {len(past_key_value)} past key / value states" - ) - - self_attn_past_key_value = past_key_value[:2] - cross_attn_past_key_value = past_key_value[2:] - else: - self_attn_past_key_value, cross_attn_past_key_value = None, None - - self_attention_outputs = self.layer[0]( - hidden_states, - attention_mask=attention_mask, - position_bias=position_bias, - layer_head_mask=layer_head_mask, - past_key_value=self_attn_past_key_value, - use_cache=use_cache, - output_attentions=output_attentions, - ) - hidden_states, present_key_value_state = self_attention_outputs[:2] - attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights - - # clamp inf values to enable fp16 training - # Custom: Remove check for inf - if hidden_states.dtype == torch.float16: - clamp_value = torch.finfo(hidden_states.dtype).max - 1000 - hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) - - do_cross_attention = self.is_decoder and encoder_hidden_states is not None - if do_cross_attention: - # the actual query length is unknown for cross attention - # if using past key value states. Need to inject it here - if present_key_value_state is not None: - query_length = present_key_value_state[0].shape[2] - else: - query_length = None - - cross_attention_outputs = self.layer[1]( - hidden_states, - key_value_states=encoder_hidden_states, - attention_mask=encoder_attention_mask, - position_bias=encoder_decoder_position_bias, - layer_head_mask=cross_attn_layer_head_mask, - past_key_value=cross_attn_past_key_value, - query_length=query_length, - use_cache=use_cache, - output_attentions=output_attentions, - ) - hidden_states = cross_attention_outputs[0] - - # clamp inf values to enable fp16 training - # Custom: Remove check for inf - if hidden_states.dtype == torch.float16: - clamp_value = torch.finfo(hidden_states.dtype).max - 1000 - hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) - - # Combine self attn and cross attn key value states - if present_key_value_state is not None: - present_key_value_state = present_key_value_state + cross_attention_outputs[1] - - # Keep cross-attention outputs and relative position weights - attention_outputs = attention_outputs + cross_attention_outputs[2:] - - # Apply Feed Forward layer - hidden_states = self.layer[-1](hidden_states) - - # clamp inf values to enable fp16 training - # Custom: Remove check for inf - if hidden_states.dtype == torch.float16: - clamp_value = torch.finfo(hidden_states.dtype).max - 1000 - hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) - - outputs = (hidden_states,) - - if use_cache: - outputs = outputs + (present_key_value_state,) + attention_outputs - else: - outputs = outputs + attention_outputs - - return outputs # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) - - -class CustomT5Stack(T5Stack): - def invert_attention_mask(self, encoder_attention_mask: Tensor) -> Tensor: - if encoder_attention_mask.dim() == 3: - encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :] - if encoder_attention_mask.dim() == 2: - encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :] - # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition - # Cf. 
https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow - # /transformer/transformer_layers.py#L270 - # encoder_extended_attention_mask = (encoder_extended_attention_mask == - # encoder_extended_attention_mask.transpose(-1, -2)) - encoder_extended_attention_mask = encoder_extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility - - # Always use -1e4 to avoid NaN issues. - encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1e4 - return encoder_extended_attention_mask +_NON_REVERSIBLE_TRANSFORMATIONS = [ + ClipValuesSymmetric(1e4, exclude_targets=["view"]), + ClipValues(1e-4, float("inf"), include_targets=[torch.nn.LayerNorm]), + TupleOutput(), +] @register(T5ForConditionalGeneration) @@ -229,6 +132,40 @@ def scale_down_weights(self, factor: float = 1, restore: bool = False): if not restore: self.lm_scale_modifier /= emb_scaling + def get_transformations(self): + log_insertions = self.ipu_config.log_insertions + layer_ipu = get_layer_ipu(self.ipu_config.layers_per_ipu) + transformations = [ + AddPoptorchBlock("Embedding", 0, "encoder.shared", log_insertions=log_insertions), + AddPoptorchBlocksInSeries( + "Encoder", layer_ipu[: self.config.num_layers], r"encoder.block.[0-9]+", log_insertions=log_insertions + ), + AddPoptorchBlocksInSeries( + "Decoder", + layer_ipu[self.config.num_layers - 1 :], + r"decoder.block.[0-9]+", + log_insertions=log_insertions, + ), + AddPoptorchBlock("LM Head Output", 0, "lm_head", log_insertions=log_insertions), + ] + if self.ipu_config.recompute_checkpoint_every_layer: + transformations += [ + RecomputationCheckpoint( + "encoder.block.[0-9]+", to_exclude=f"encoder.block.{self.config.num_layers - 1}" + ), + RecomputationCheckpoint( + "decoder.block.[0-9]+", to_exclude=f"decoder.block.{self.config.num_layers - 1}" + ), + ] + + if self.ipu_config.embedding_serialization_factor > 1: + transformations += [ + LinearToSerializedLinear("lm_head"), + TieWeights("shared", "lm_head"), + ] + transformations += [ShareEmbeddingComputation()] + return transformations + def parallelize(self): """ Transform the model to run in an IPU pipeline. @@ -241,32 +178,47 @@ def parallelize(self): model = PipelinedT5ForConditionalGeneration(config).parallelize().half() ``` """ - layer_ipu = get_layer_ipu(self.ipu_config.layers_per_ipu) - - logger.info("-------------------- Device Allocation --------------------") - logger.info("Embedding --> IPU 0") - - if self.ipu_config.embedding_serialization_factor > 1: - serialized_lm_head = SerializedLinear( - self.config.d_model, - self.shared.num_embeddings, - self.ipu_config.embedding_serialization_factor, - bias=False, - mode=poptorch.MatMulSerializationMode.OutputChannels, - ) - serialized_lm_head.load_state_dict(self.lm_head.state_dict()) - self.lm_head = serialized_lm_head - # TODO: is it needed to check? 
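The serialized-lm_head block being commented out here is what `LinearToSerializedLinear("lm_head")` plus `TieWeights("shared", "lm_head")` in `get_transformations` are meant to replace. A self-contained toy of the re-tying step only, using plain `nn` modules rather than the real `SerializedLinear`:
```
import torch

# Stand-ins for the shared embedding and the (serialized) LM head.
shared = torch.nn.Embedding(32128, 512)
lm_head = torch.nn.Linear(512, 32128, bias=False)

# This is what TieWeights.transform does for ("shared", "lm_head"):
setattr(lm_head, "weight", getattr(shared, "weight"))
assert lm_head.weight is shared.weight  # the two modules now share one parameter
```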
- if self.config.tie_word_embeddings: - self.tie_weights() + PipelineMixin.parallelize(self) + for mod in self.modules(): + if isinstance(mod, T5LayerNorm): + mod.forward = poptorch.autocast(enabled=True)(mod.forward) + traced = symbolic_trace_pipelined_model(self) + transformations = self.get_transformations() + # transformations += _OPTIMIZATION_TRANSFORMATIONS + composition = compose(*transformations) + non_reversible_composition = compose(*_NON_REVERSIBLE_TRANSFORMATIONS) + traced = composition(traced) + traced = non_reversible_composition(traced) + import ipdb + + ipdb.set_trace() + return traced + # layer_ipu = get_layer_ipu(self.ipu_config.layers_per_ipu) + + # logger.info("-------------------- Device Allocation --------------------") + # logger.info("Embedding --> IPU 0") + + # if self.ipu_config.embedding_serialization_factor > 1: + # serialized_lm_head = SerializedLinear( + # self.config.d_model, + # self.shared.num_embeddings, + # self.ipu_config.embedding_serialization_factor, + # bias=False, + # mode=poptorch.MatMulSerializationMode.OutputChannels, + # ) + # serialized_lm_head.load_state_dict(self.lm_head.state_dict()) + # self.lm_head = serialized_lm_head + # # TODO: is it needed to check? + # if self.config.tie_word_embeddings: + # self.tie_weights() # self.scale_down_weights(factor=1) - self.encoder_and_decoder_embeddings_computation(True) - self.shared = poptorch.BeginBlock(self.shared, "Embedding", ipu_id=0) + # self.encoder_and_decoder_embeddings_computation(True) + # self.shared = poptorch.BeginBlock(self.shared, "Embedding", ipu_id=0) # Use a custom T5Stack implementation because sharing the position bias causes OOM error - self.encoder.__class__ = CustomT5Stack - self.decoder.__class__ = CustomT5Stack + # self.encoder.__class__ = CustomT5Stack + # self.decoder.__class__ = CustomT5Stack # Use a custom T5Block implementation that removes a dynamic if blocks that can't be statically traced for block in self.encoder.block: @@ -274,32 +226,32 @@ def parallelize(self): for block in self.decoder.block: block.__class__ = CustomT5Block - for index, layer in enumerate(self.encoder.block): - ipu = layer_ipu[index] - if self.ipu_config.recompute_checkpoint_every_layer and index != self.config.num_layers - 1: - recomputation_checkpoint(layer) - self.encoder.block[index] = poptorch.BeginBlock(layer, f"Encoder{index}", ipu_id=ipu) - logger.info(f"Encoder {index:<2} --> IPU {ipu}") - - self.encoder.final_layer_norm = poptorch.BeginBlock( - self.encoder.final_layer_norm, "Encoder Stack Final LayerNorm", ipu_id=ipu - ) - - shift = len(self.encoder.block) - for index, layer in enumerate(self.decoder.block): - ipu = layer_ipu[index + shift] - if self.ipu_config.recompute_checkpoint_every_layer and index != self.config.num_layers - 1: - recomputation_checkpoint(layer) - self.decoder.block[index] = poptorch.BeginBlock(layer, f"Decoder{index}", ipu_id=ipu) - logger.info(f"Decoder {index:<2} --> IPU {ipu}") - - self.decoder.final_layer_norm = poptorch.BeginBlock( - self.decoder.final_layer_norm, "Decoder Stack Final LayerNorm", ipu_id=ipu - ) - - logger.info("LM Head Output --> IPU 0") - self.lm_head = poptorch.BeginBlock(self.lm_head, "LM Head Output", ipu_id=0) - logger.info("-----------------------------------------------------------") + # for index, layer in enumerate(self.encoder.block): + # ipu = layer_ipu[index] + # if self.ipu_config.recompute_checkpoint_every_layer and index != self.config.num_layers - 1: + # recomputation_checkpoint(layer) + # self.encoder.block[index] = 
poptorch.BeginBlock(layer, f"Encoder{index}", ipu_id=ipu) + # logger.info(f"Encoder {index:<2} --> IPU {ipu}") + + # self.encoder.final_layer_norm = poptorch.BeginBlock( + # self.encoder.final_layer_norm, "Encoder Stack Final LayerNorm", ipu_id=ipu + # ) + + # shift = len(self.encoder.block) + # for index, layer in enumerate(self.decoder.block): + # ipu = layer_ipu[index + shift] + # if self.ipu_config.recompute_checkpoint_every_layer and index != self.config.num_layers - 1: + # recomputation_checkpoint(layer) + # self.decoder.block[index] = poptorch.BeginBlock(layer, f"Decoder{index}", ipu_id=ipu) + # logger.info(f"Decoder {index:<2} --> IPU {ipu}") + + # self.decoder.final_layer_norm = poptorch.BeginBlock( + # self.decoder.final_layer_norm, "Decoder Stack Final LayerNorm", ipu_id=ipu + # ) + + # logger.info("LM Head Output --> IPU 0") + # self.lm_head = poptorch.BeginBlock(self.lm_head, "LM Head Output", ipu_id=0) + # logger.info("-----------------------------------------------------------") return self def deparallelize(self): From 5cbb50e0ce600d237ade5adac1dce54a126e3a6c Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Wed, 10 Aug 2022 15:41:53 +0200 Subject: [PATCH 08/33] [WIP] BART --- .../graphcore/models/bart/modeling_bart.py | 692 +++--------------- .../models/hubert/modeling_hubert.py | 91 ++- optimum/graphcore/models/t5/modeling_t5.py | 11 +- optimum/graphcore/models/vit/modeling_vit.py | 78 +- optimum/graphcore/trainer.py | 3 +- 5 files changed, 224 insertions(+), 651 deletions(-) diff --git a/optimum/graphcore/models/bart/modeling_bart.py b/optimum/graphcore/models/bart/modeling_bart.py index b521de29b..524a4db0d 100644 --- a/optimum/graphcore/models/bart/modeling_bart.py +++ b/optimum/graphcore/models/bart/modeling_bart.py @@ -11,47 +11,54 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
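The `poptorch.BeginBlock` loops commented out in the T5 `parallelize` above are what `AddPoptorchBlocksInSeries` expresses declaratively, and the BART changes below follow the same pattern. A rough side-by-side, with the behaviour of `get_layer_ipu` inferred from how it is indexed in those loops:
```
from optimum.graphcore.modeling_utils import get_layer_ipu

# Assumption: get_layer_ipu expands per-IPU layer counts into a per-layer IPU index.
layer_ipu = get_layer_ipu([2, 2])   # e.g. [0, 0, 1, 1]

# old, imperative (per layer, inside parallelize()):
#   self.encoder.block[i] = poptorch.BeginBlock(layer, f"Encoder{i}", ipu_id=layer_ipu[i])
# new, declarative (one entry in get_transformations()):
#   AddPoptorchBlocksInSeries("Encoder", layer_ipu, r"encoder.block.[0-9]+")
```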
- -import random -from typing import List, Optional, Tuple, Union +from typing import Optional, Tuple import torch import torch.nn as nn -import poptorch -from optimum.utils import logging -from transformers import BartForConditionalGeneration, BartForSequenceClassification, BartModel -from transformers.modeling_outputs import ( - BaseModelOutput, - BaseModelOutputWithPastAndCrossAttentions, - Seq2SeqLMOutput, - Seq2SeqModelOutput, - Seq2SeqSequenceClassifierOutput, -) -from transformers.models.bart.modeling_bart import ( - BartAttention, - BartDecoder, - BartEncoder, - BartEncoderLayer, - shift_tokens_right, -) +import transformers +from transformers import BartForConditionalGeneration +from transformers.models.bart.modeling_bart import BartAttention +from optimum.utils import logging +from ....fx.optimization import ChangeTrueDivToMulByInverse, MergeLinears, ReversibleTransformation, compose from ...generation_utils import IPUGenerationMixin from ...modeling_utils import ( GenerationMethodsMixin, PipelineMixin, - SerializedLinear, - SharedEmbedding, get_layer_ipu, - recomputation_checkpoint, register, ) +from ...fx.utils import symbolic_trace_pipelined_model +from ...fx.transformations import ( + AddPoptorchBlock, + AddPoptorchBlocksInSeries, + ClipValues, + ClipValuesSymmetric, + LinearToSerializedLinear, + RecomputationCheckpoint, + ShareEmbeddingComputation, + TieWeights, + TupleOutput, +) logger = logging.get_logger(__name__) FLOAT16_LIMIT = 1e4 +_OPTIMIZATION_TRANSFORMATIONS = [ + ChangeTrueDivToMulByInverse(), + MergeLinears(), + # FuseBiasInLinear(), +] + +_NON_REVERSIBLE_TRANSFORMATIONS = [ + ClipValuesSymmetric(1e4, exclude_targets=["view"]), + ClipValues(1e-4, float("inf"), include_targets=[torch.nn.LayerNorm]), + TupleOutput(), +] + def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0): """Makes causal mask used for bi-directional self-attention. @@ -78,9 +85,8 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] expanded_mask = mask[:, None, None, :] inverted_mask = 1.0 - expanded_mask - # Using FLOAT16_LIMIT instead of -float("inf") to avoid NaNs on the IPUs. 
- inverted_mask = -FLOAT16_LIMIT * inverted_mask - return inverted_mask.to(dtype) + inverted_mask = -float("inf") * inverted_mask + return inverted_mask class _BartAttentionWithoutException(BartAttention): @@ -148,10 +154,10 @@ def forward( src_len = key_states.size(1) attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) - if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): - raise ValueError( - f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}" - ) + # if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): + # raise ValueError( + # f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}" + # ) if attention_mask is not None: attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask @@ -181,10 +187,10 @@ def forward( attn_output = torch.bmm(attn_probs, value_states) - if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}" - ) + # if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): + # raise ValueError( + # f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}" + # ) attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) attn_output = attn_output.transpose(1, 2) @@ -198,493 +204,49 @@ def forward( return attn_output, attn_weights_reshaped, past_key_value -class _BartEncoderLayerNoClamp(BartEncoderLayer): - """ - Same as BartEncoderLayer except it removed the dynamic if statement - for clamping fp16 tensor values. - """ - - def forward( - self, - hidden_states: torch.FloatTensor, - attention_mask: torch.FloatTensor, - layer_head_mask: torch.FloatTensor, - output_attentions: Optional[bool] = False, - ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)` - attention_mask (`torch.FloatTensor`): attention mask of size - `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. - layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size - `(encoder_attention_heads,)`. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - """ - residual = hidden_states - hidden_states, attn_weights, _ = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - layer_head_mask=layer_head_mask, - output_attentions=output_attentions, - ) - hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - hidden_states = residual + hidden_states - hidden_states = self.self_attn_layer_norm(hidden_states) - - residual = hidden_states - hidden_states = self.activation_fn(self.fc1(hidden_states)) - hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) - hidden_states = self.fc2(hidden_states) - hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - hidden_states = residual + hidden_states - hidden_states = self.final_layer_norm(hidden_states) - - # Change: removing this `if` because it can't be statically compiled. 
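The shape checks commented out above and the fp16 clamp guard flagged just above are data-dependent Python control flow, which a symbolic tracer cannot follow. A minimal vanilla `torch.fx` reproduction of the failure (not PopTorch-specific, and independent of the HFTracer used in this series):
```
import torch
from torch.fx import symbolic_trace

class ClampIfOverflow(torch.nn.Module):
    def forward(self, x):
        # Data-dependent branch, similar in spirit to the removed clamp guard.
        if torch.isinf(x).any():
            x = torch.clamp(x, min=-1e4, max=1e4)
        return x

try:
    symbolic_trace(ClampIfOverflow())
except torch.fx.proxy.TraceError as e:
    print(e)  # symbolically traced variables cannot be used as inputs to control flow
```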
- # if hidden_states.dtype == torch.float16 and ( - # torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any() - # ): - # clamp_value = torch.finfo(hidden_states.dtype).max - 1000 - # hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) - - outputs = (hidden_states,) - - if output_attentions: - outputs += (attn_weights,) - - return outputs - - -class _BartEncoderWithCustomExpandMask(BartEncoder): - """The same as BartEncoder but uses a custom version of _expand_mask. - - Check the _expand_mask docstring for more information. - """ - - def forward( - self, - input_ids=None, - attention_mask=None, - head_mask=None, - inputs_embeds=None, - output_attentions=None, - output_hidden_states=None, - return_dict=False, - ): - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - input_shape = input_ids.size() - input_ids = input_ids.view(-1, input_shape[-1]) - elif inputs_embeds is not None: - input_shape = inputs_embeds.size()[:-1] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale - - embed_pos = self.embed_positions(input_shape) - - hidden_states = inputs_embeds + embed_pos - hidden_states = self.layernorm_embedding(hidden_states) - hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - - # expand attention_mask - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype) - - encoder_states = () if output_hidden_states else None - all_attentions = () if output_attentions else None - - # check if head_mask has a correct number of layers specified if desired - if head_mask is not None: - if head_mask.size()[0] != (len(self.layers)): - raise ValueError( - f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}." 
- ) - - for idx, encoder_layer in enumerate(self.layers): - if output_hidden_states: - encoder_states = encoder_states + (hidden_states,) - # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) - dropout_probability = random.uniform(0, 1) - if self.training and (dropout_probability < self.layerdrop): # skip the layer - layer_outputs = (None, None) - else: - if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), - hidden_states, - attention_mask, - (head_mask[idx] if head_mask is not None else None), - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - output_attentions=output_attentions, - ) - - hidden_states = layer_outputs[0] - - if output_attentions: - all_attentions = all_attentions + (layer_outputs[1],) - - if output_hidden_states: - encoder_states = encoder_states + (hidden_states,) - - if not return_dict: - return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) - return BaseModelOutput( - last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions - ) - - -class _BartDecoderWithCustomMakeCausalAndExpandMask(BartDecoder): - """The same as BartDecoder but uses a custom version of _make_causal_mask and _expand_mask. - - Check the _expand_mask docstring for more information. - """ - - def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = None - if input_shape[-1] > 1: - combined_attention_mask = _make_causal_mask( - input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length - ).to(self.device) - - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]) - combined_attention_mask = ( - expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask - ) - - return combined_attention_mask - - def forward( - self, - input_ids=None, - attention_mask=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - head_mask=None, - cross_attn_head_mask=None, - past_key_values=None, - inputs_embeds=None, - use_cache=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - ): - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") - elif input_ids is not None: - input_shape = input_ids.size() - input_ids = input_ids.view(-1, input_shape[-1]) - elif inputs_embeds is not None: - input_shape = inputs_embeds.size()[:-1] - else: - raise ValueError("You have to specify either 
decoder_input_ids or decoder_inputs_embeds") - - # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale - - attention_mask = self._prepare_decoder_attention_mask( - attention_mask, input_shape, inputs_embeds, past_key_values_length - ) - - # expand encoder attention mask - if encoder_hidden_states is not None and encoder_attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]) - - # embed positions - positions = self.embed_positions(input_shape, past_key_values_length) - - hidden_states = inputs_embeds + positions - hidden_states = self.layernorm_embedding(hidden_states) - - hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None - next_decoder_cache = () if use_cache else None - - # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired - for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): - if attn_mask is not None: - pass - # if attn_mask.size()[0] != (len(self.layers)): - # raise ValueError( - # "The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}." - # ) - - for idx, decoder_layer in enumerate(self.layers): - # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) - if output_hidden_states: - all_hidden_states += (hidden_states,) - dropout_probability = random.uniform(0, 1) - if self.training and (dropout_probability < self.layerdrop): - continue - - past_key_value = past_key_values[idx] if past_key_values is not None else None - - if self.gradient_checkpointing and self.training: - - if use_cache: - logger.warning( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
- ) - use_cache = False - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, use_cache) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), - hidden_states, - attention_mask, - encoder_hidden_states, - encoder_attention_mask, - head_mask[idx] if head_mask is not None else None, - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, - None, - ) - else: - - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - cross_attn_layer_head_mask=( - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None - ), - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) - hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache += (layer_outputs[3 if output_attentions else 1],) - - if output_attentions: - all_self_attns += (layer_outputs[1],) - - if encoder_hidden_states is not None: - all_cross_attentions += (layer_outputs[2],) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - - next_cache = next_decoder_cache if use_cache else None - if not return_dict: - return tuple( - v - for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] - if v is not None - ) - return BaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - cross_attentions=all_cross_attentions, - ) - - -class _BartModelWithSharedEmbedding(BartModel): - @property - def is_encoder_and_decoder_embeddings_computation_shared(self): - return isinstance(self.shared, SharedEmbedding) - - def encoder_and_decoder_embeddings_computation(self, use_shared_embedding: bool): - """Sets the BartModel shared embedding layer to SharedEmbedding that combines the computation under one layer. - - Args: - use_shared_embedding: whether to use SharedEmbedding or not. - """ - - if use_shared_embedding: - if isinstance(self.shared, SharedEmbedding): - logger.warning("encoder and decoder embeddings computation is already shared") - else: - self.shared = SharedEmbedding(self.shared) - else: - if isinstance(self.shared, nn.Embedding): - logger.warning("encoder and decoder embeddings computation is not shared") - else: - self.shared = self.shared.shared - - def change_bart_encoder_and_decoder_classes(self, restore: bool): - """Changes the encoder and decoder classes to update their forward pass so that they use our custom versions of - _make_causal_mask and _expand_mask. - - Args: - restore: whether to restore the encoder and decoder to their original version or not. - """ - self.encoder.__class__ = BartEncoder if restore else _BartEncoderWithCustomExpandMask - self.decoder.__class__ = BartDecoder if restore else _BartDecoderWithCustomMakeCausalAndExpandMask - for layer in self.encoder.layers: - layer.__class__ = BartEncoderLayer if restore else _BartEncoderLayerNoClamp - - def change_bart_attention_class(self, restore: bool): - """Changes the attention layers to either use the original BartAttention forward or - BartAttentionWithoutException forward. 
- - Args: - restore: whether to restore the attention layers to their original version or not. - """ - new_cls = BartAttention if restore else _BartAttentionWithoutException - for mod in self.modules(): - if isinstance(mod, BartAttention): - mod.__class__ = new_cls - - def forward( - self, - input_ids=None, - attention_mask=None, - decoder_input_ids=None, - decoder_attention_mask=None, - head_mask=None, - decoder_head_mask=None, - cross_attn_head_mask=None, - encoder_outputs=None, - past_key_values=None, - inputs_embeds=None, - decoder_inputs_embeds=None, - use_cache=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - ): - - # different to other models, Bart automatically creates decoder_input_ids from - # input_ids if no decoder_input_ids are provided - if decoder_input_ids is None and decoder_inputs_embeds is None: - if input_ids is None: - raise ValueError( - "If no `decoder_input_ids` or `decoder_inputs_embeds` are " - "passed, `input_ids` cannot be `None`. Please pass either " - "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`." - ) - - decoder_input_ids = shift_tokens_right( - input_ids, self.config.pad_token_id, self.config.decoder_start_token_id - ) - - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if encoder_outputs is None: - if self.is_encoder_and_decoder_embeddings_computation_shared: - inputs_embeds, decoder_inputs_embeds = self.shared( - input_ids=input_ids, - decoder_input_ids=decoder_input_ids, - encoder_embed_scale=self.encoder.embed_scale, - decoder_embed_scale=self.decoder.embed_scale, - ) - if inputs_embeds is not None: - input_ids = None - if decoder_inputs_embeds is not None: - decoder_input_ids = None - encoder_outputs = self.encoder( - input_ids=input_ids, - attention_mask=attention_mask, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True - elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): - encoder_outputs = BaseModelOutput( - last_hidden_state=encoder_outputs[0], - hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, - attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, - ) - - # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) - decoder_outputs = self.decoder( - input_ids=decoder_input_ids, - attention_mask=decoder_attention_mask, - encoder_hidden_states=encoder_outputs[0], - encoder_attention_mask=attention_mask, - head_mask=decoder_head_mask, - cross_attn_head_mask=cross_attn_head_mask, - past_key_values=past_key_values, - inputs_embeds=decoder_inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - if not return_dict: - return decoder_outputs + encoder_outputs - - return Seq2SeqModelOutput( - last_hidden_state=decoder_outputs.last_hidden_state, - past_key_values=decoder_outputs.past_key_values, - decoder_hidden_states=decoder_outputs.hidden_states, - 
decoder_attentions=decoder_outputs.attentions, - cross_attentions=decoder_outputs.cross_attentions, - encoder_last_hidden_state=encoder_outputs.last_hidden_state, - encoder_hidden_states=encoder_outputs.hidden_states, - encoder_attentions=encoder_outputs.attentions, - ) - - @register(BartForConditionalGeneration) class PipelinedBartForConditionalGeneration( GenerationMethodsMixin, BartForConditionalGeneration, PipelineMixin, IPUGenerationMixin ): + def get_transformations(self): + log_insertions = self.ipu_config.log_insertions + layer_ipu = get_layer_ipu(self.ipu_config.layers_per_ipu) + transformations = [ + AddPoptorchBlock("Embedding", 0, "model.shared", log_insertions=log_insertions), + # AddPoptorchBlock("Embedding", 0, "model.encoder.embed_positions"), + # AddPoptorchBlock("Embedding", 0, "model.encoder.layernorm_embedding"), + # AddPoptorchBlock("Embedding", 0, "model.decoder.embed_positions"), + # AddPoptorchBlock("Embedding", 0, "model.decoder.layernorm_embedding"), + AddPoptorchBlocksInSeries( + "Encoder", layer_ipu[: self.config.encoder_layers], r"model.encoder.layers.[0-9]+", log_insertions=log_insertions + ), + AddPoptorchBlocksInSeries( + "Decoder", + layer_ipu[self.config.encoder_layers:], + r"model.decoder.layers.[0-9]+", + log_insertions=log_insertions, + ), + AddPoptorchBlock("LM Head Output", 0, "lm_head", log_insertions=log_insertions), + ] + if self.ipu_config.recompute_checkpoint_every_layer: + transformations += [ + RecomputationCheckpoint( + "model.encoder.layers.[0-9]+", to_exclude=f"model.encoder.layers.{self.config.encoder_layers - 1}" + ), + RecomputationCheckpoint( + "model.decoder.layers.[0-9]+", to_exclude=f"model.decoder.layers.{self.config.decoder_layers - 1}" + ), + ] + + if not isinstance(self, torch.fx.GraphModule): + if self.ipu_config.embedding_serialization_factor > 1: + transformations += [ + LinearToSerializedLinear("lm_head"), + TieWeights("model.shared", "lm_head"), + ] + transformations += [ShareEmbeddingComputation()] + return transformations + def parallelize(self): """ Transform the model to run in an IPU pipeline. 
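For reference, a minimal, self-contained sketch of how a `layers_per_ipu` specification like the one consumed above could expand into per-layer IPU ids and be split between the BART encoder and decoder stacks. `expand_layers_per_ipu` is a hypothetical stand-in for `get_layer_ipu`, and the layer counts are illustrative, not taken from this patch.

```
from typing import List


def expand_layers_per_ipu(layers_per_ipu: List[int]) -> List[int]:
    # Hypothetical stand-in for get_layer_ipu: IPU i hosts layers_per_ipu[i]
    # transformer layers, so expand the per-IPU counts into one IPU id per layer.
    layer_ipu = []
    for ipu_id, num_layers in enumerate(layers_per_ipu):
        layer_ipu += [ipu_id] * num_layers
    return layer_ipu


# Illustrative BART-base-like split: 6 encoder + 6 decoder layers over 4 IPUs.
encoder_layers, decoder_layers = 6, 6
layer_ipu = expand_layers_per_ipu([3, 3, 3, 3])
encoder_ipus = layer_ipu[:encoder_layers]  # [0, 0, 0, 1, 1, 1]
decoder_ipus = layer_ipu[encoder_layers:]  # [2, 2, 2, 3, 3, 3]
assert len(encoder_ipus) == encoder_layers and len(decoder_ipus) == decoder_layers
```

This mirrors the `layer_ipu[: self.config.encoder_layers]` and `layer_ipu[self.config.encoder_layers :]` slices passed to `AddPoptorchBlocksInSeries` above.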
@@ -698,62 +260,26 @@ def parallelize(self): ``` """ super().parallelize() - - layer_ipu = get_layer_ipu(self.ipu_config.layers_per_ipu) - - logger.info("-------------------- Device Allocation --------------------") - logger.info("Embedding --> IPU 0") - - if self.ipu_config.embedding_serialization_factor > 1: - serialized_lm_head = SerializedLinear( - self.config.d_model, - self.model.shared.num_embeddings, - self.ipu_config.embedding_serialization_factor, - bias=False, - mode=poptorch.MatMulSerializationMode.OutputChannels, - ) - serialized_lm_head.load_state_dict(self.lm_head.state_dict()) - self.lm_head = serialized_lm_head - self.tie_weights() - - self.model.__class__ = _BartModelWithSharedEmbedding - self.model.encoder_and_decoder_embeddings_computation(True) - self.model.change_bart_encoder_and_decoder_classes(False) - self.model.change_bart_attention_class(False) - - self.model.shared = poptorch.BeginBlock(self.model.shared, "Embedding", ipu_id=0) - self.model.encoder.embed_positions = poptorch.BeginBlock( - self.model.encoder.embed_positions, "Embedding", ipu_id=0 - ) - self.model.encoder.layernorm_embedding = poptorch.BeginBlock( - self.model.encoder.layernorm_embedding, "Embedding", ipu_id=0 - ) - - for index, layer in enumerate(self.model.encoder.layers): - ipu = layer_ipu[index] - if self.ipu_config.recompute_checkpoint_every_layer and index != self.config.num_hidden_layers - 1: - recomputation_checkpoint(layer) - self.model.encoder.layers[index] = poptorch.BeginBlock(layer, f"Encoder{index}", ipu_id=ipu) - logger.info(f"Encoder {index:<2} --> IPU {ipu}") - - self.model.decoder.embed_positions = poptorch.BeginBlock( - self.model.decoder.embed_positions, "Embedding", ipu_id=0 - ) - self.model.decoder.layernorm_embedding = poptorch.BeginBlock( - self.model.decoder.layernorm_embedding, "Embedding", ipu_id=0 - ) - shift = len(self.model.encoder.layers) - for index, layer in enumerate(self.model.decoder.layers): - ipu = layer_ipu[index + shift] - if self.ipu_config.recompute_checkpoint_every_layer and index != self.config.num_hidden_layers - 1: - recomputation_checkpoint(layer) - self.model.decoder.layers[index] = poptorch.BeginBlock(layer, f"Decoder{index}", ipu_id=ipu) - logger.info(f"Decoder {index:<2} --> IPU {ipu}") - - logger.info("LM Head Output --> IPU 0") - self.lm_head = poptorch.BeginBlock(self.lm_head, "LM Head Output", ipu_id=0) - logger.info("-----------------------------------------------------------") - return self + if not isinstance(self, torch.fx.GraphModule): + orig_make_causal_mask = transformers.models.bart.modeling_bart._make_causal_mask + orig_expand_mask = transformers.models.bart.modeling_bart._expand_mask + transformers.models.bart.modeling_bart._make_causal_mask = _make_causal_mask + transformers.models.bart.modeling_bart._expand_mask = _expand_mask + for mod in self.modules(): + if isinstance(mod, BartAttention): + mod.__class__ = _BartAttentionWithoutException + traced = symbolic_trace_pipelined_model(self) + transformers.models.bart.modeling_bart._make_causal_mask = orig_make_causal_mask + transformers.models.bart.modeling_bart._expand_mask = orig_expand_mask + else: + traced = self + transformations = self.get_transformations() + transformations += _OPTIMIZATION_TRANSFORMATIONS + composition = compose(*transformations) + non_reversible_composition = compose(*_NON_REVERSIBLE_TRANSFORMATIONS) + traced = composition(traced) + traced = non_reversible_composition(traced) + return traced def deparallelize(self): """ @@ -762,21 +288,11 @@ def 
deparallelize(self): fully compatible with `transformers.BartForConditionalGeneration`. """ super().deparallelize() - self.model.encoder_and_decoder_embeddings_computation(False) - self.model.change_bart_encoder_and_decoder_classes(True) - self.model.change_bart_attention_class(True) - self.model.__class__ = BartModel - - if self.ipu_config.embedding_serialization_factor > 1: - old_lm_head = nn.Linear( - self.config.d_model, - self.model.shared.num_embeddings, - bias=False, - ) - old_lm_head.load_state_dict(self.lm_head.state_dict()) - self.lm_head = old_lm_head - self.tie_weights() - + transformations = self.get_transformations() + transformations += _OPTIMIZATION_TRANSFORMATIONS + transformations = [t for t in transformations if isinstance(t, ReversibleTransformation)] + composition = compose(*transformations) + self = composition(self, reverse=True) return self def forward( diff --git a/optimum/graphcore/models/hubert/modeling_hubert.py b/optimum/graphcore/models/hubert/modeling_hubert.py index 77b77f621..b274d9ead 100644 --- a/optimum/graphcore/models/hubert/modeling_hubert.py +++ b/optimum/graphcore/models/hubert/modeling_hubert.py @@ -11,53 +11,80 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import torch import poptorch -from optimum.utils import logging from transformers import HubertForSequenceClassification from transformers.models.hubert.modeling_hubert import HubertEncoder, HubertEncoderStableLayerNorm -from ...modeling_utils import PipelineMixin, get_layer_ipu, recomputation_checkpoint, register -from .ipu_layer_drop import IPUHubertEncoder, IPUHubertEncoderStableLayerNorm +from ....utils import logging +from ....fx.optimization import ChangeTrueDivToMulByInverse, MergeLinears, compose +from ...modeling_utils import PipelineMixin, get_layer_ipu, register +from ...fx.utils import symbolic_trace_pipelined_model +from ...fx.transformations import ( + AddPoptorchBlock, + AddPoptorchBlocksInSeries, + ClipValues, + ClipValuesSymmetric, + RecomputationCheckpoint, + TupleOutput, +) logger = logging.get_logger(__name__) +_OPTIMIZATION_TRANSFORMATIONS = [ + ChangeTrueDivToMulByInverse(), + MergeLinears(), + # FuseBiasInLinear(), +] + +_NON_REVERSIBLE_TRANSFORMATIONS = [ + ClipValuesSymmetric(1e4, exclude_targets=["view"]), + ClipValues(1e-4, float("inf"), include_targets=[torch.nn.LayerNorm]), + TupleOutput(), +] + @register(HubertForSequenceClassification) class PipelinedHubertForSequenceClassification(HubertForSequenceClassification, PipelineMixin): - def change_hubert_encoder_class(self, restore: bool): - """Changes the encoder class to update its forward pass so that it uses our custom version. - - Args: - restore: whether to restore the encoder to its original version or not. 
- """ - if self.config.do_stable_layer_norm: - new_cls = HubertEncoderStableLayerNorm if restore else IPUHubertEncoderStableLayerNorm - else: - new_cls = HubertEncoder if restore else IPUHubertEncoder - self.hubert.encoder.__class__ = new_cls + def get_transformations(self): + log_insertions = self.ipu_config.log_insertions + layer_ipu = get_layer_ipu(self.ipu_config.layers_per_ipu) + transformations = [ + AddPoptorchBlock("Feature Extractor", 0, "hubert.feature_extractor", log_insertions=log_insertions), + AddPoptorchBlock("Feature Projection", 0, "hubert.feature_projection", log_insertions=log_insertions), + AddPoptorchBlock("Encoder", 0, "hubert.encoder", log_insertions=log_insertions), + AddPoptorchBlocksInSeries( + "Encoder", layer_ipu, r"hubert.encoder.layers.[0-9]+", log_insertions=log_insertions + ), + AddPoptorchBlock("Projector", layer_ipu[-1], "projector", log_insertions=log_insertions), + AddPoptorchBlock("Classifier", layer_ipu[-1], "classifier", log_insertions=log_insertions), + ] + if self.ipu_config.recompute_checkpoint_every_layer: + transformations += [ + RecomputationCheckpoint( + "hubert.encoder.layers.[0-9]+", to_exclude=f"hubert.encoder.layers.{self.config.num_layers - 1}" + ), + ] + return transformations def parallelize(self): super().parallelize() + traced = symbolic_trace_pipelined_model(self) + transformations = self.get_transformations() + transformations += _OPTIMIZATION_TRANSFORMATIONS + composition = compose(*transformations) + non_reversible_composition = compose(*_NON_REVERSIBLE_TRANSFORMATIONS) + traced = composition(traced) + traced = non_reversible_composition(traced) + return traced - self.change_hubert_encoder_class(False) - - self.hubert.feature_extractor = poptorch.BeginBlock(self.hubert.feature_extractor, ipu_id=0) - self.hubert.feature_projection = poptorch.BeginBlock(self.hubert.feature_projection, ipu_id=0) - self.hubert.encoder = poptorch.BeginBlock(self.hubert.encoder, ipu_id=0) - - layer_ipu = get_layer_ipu(self.ipu_config.layers_per_ipu) - for index, layer in enumerate(self.hubert.encoder.layers): - # Put checkpoints on every encoder layer - h = recomputation_checkpoint(layer) - self._hooks.append(h) - ipu = layer_ipu[index] - self.hubert.encoder.layers[index] = poptorch.BeginBlock(layer, f"Encoder{index}", ipu_id=ipu) - - last_ipu = self.ipu_config.ipus_per_replica - 1 - self.projector = poptorch.BeginBlock(self.projector, ipu_id=last_ipu) - self.classifier = poptorch.BeginBlock(self.classifier, ipu_id=last_ipu) + def deparallelize(self): + super().deparallelize() + transformations = self.get_transformations() + transformations += _OPTIMIZATION_TRANSFORMATIONS + composition = compose(*transformations) + self = composition(self, reverse=True) return self def deparallelize(self): diff --git a/optimum/graphcore/models/t5/modeling_t5.py b/optimum/graphcore/models/t5/modeling_t5.py index 71daf9e11..646714a62 100644 --- a/optimum/graphcore/models/t5/modeling_t5.py +++ b/optimum/graphcore/models/t5/modeling_t5.py @@ -136,7 +136,7 @@ def get_transformations(self): log_insertions = self.ipu_config.log_insertions layer_ipu = get_layer_ipu(self.ipu_config.layers_per_ipu) transformations = [ - AddPoptorchBlock("Embedding", 0, "encoder.shared", log_insertions=log_insertions), + AddPoptorchBlock("Embedding", 0, "shared", log_insertions=log_insertions), AddPoptorchBlocksInSeries( "Encoder", layer_ipu[: self.config.num_layers], r"encoder.block.[0-9]+", log_insertions=log_insertions ), @@ -189,9 +189,7 @@ def parallelize(self): non_reversible_composition = 
compose(*_NON_REVERSIBLE_TRANSFORMATIONS) traced = composition(traced) traced = non_reversible_composition(traced) - import ipdb - - ipdb.set_trace() + import ipdb; ipdb.set_trace() return traced # layer_ipu = get_layer_ipu(self.ipu_config.layers_per_ipu) @@ -220,11 +218,6 @@ def parallelize(self): # self.encoder.__class__ = CustomT5Stack # self.decoder.__class__ = CustomT5Stack - # Use a custom T5Block implementation that removes a dynamic if blocks that can't be statically traced - for block in self.encoder.block: - block.__class__ = CustomT5Block - for block in self.decoder.block: - block.__class__ = CustomT5Block # for index, layer in enumerate(self.encoder.block): # ipu = layer_ipu[index] diff --git a/optimum/graphcore/models/vit/modeling_vit.py b/optimum/graphcore/models/vit/modeling_vit.py index 856b9ee04..281a5b8ac 100644 --- a/optimum/graphcore/models/vit/modeling_vit.py +++ b/optimum/graphcore/models/vit/modeling_vit.py @@ -11,38 +11,74 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import torch -import poptorch import transformers -from optimum.utils import logging -from ...modeling_utils import PipelineMixin, get_layer_ipu, recomputation_checkpoint, register +from ....utils import logging +from ....fx.optimization import ChangeTrueDivToMulByInverse, MergeLinears, compose +from ...modeling_utils import PipelineMixin, get_layer_ipu, register +from ...fx.utils import symbolic_trace_pipelined_model +from ...fx.transformations import ( + AddPoptorchBlock, + AddPoptorchBlocksInSeries, + ClipValues, + ClipValuesSymmetric, + RecomputationCheckpoint, + TupleOutput, +) logger = logging.get_logger(__name__) +_OPTIMIZATION_TRANSFORMATIONS = [ + ChangeTrueDivToMulByInverse(), + MergeLinears(), + # FuseBiasInLinear(), +] + +_NON_REVERSIBLE_TRANSFORMATIONS = [ + ClipValuesSymmetric(1e4, exclude_targets=["view"]), + ClipValues(1e-4, float("inf"), include_targets=[torch.nn.LayerNorm]), + TupleOutput(), +] + @register(transformers.ViTForImageClassification) class PipelinedViTForImageClassification(transformers.ViTForImageClassification, PipelineMixin): + def get_transformations(self): + log_insertions = self.ipu_config.log_insertions + layer_ipu = get_layer_ipu(self.ipu_config.layers_per_ipu) + transformations = [ + AddPoptorchBlock("Embedding", 0, "vit.embeddings", log_insertions=log_insertions), + AddPoptorchBlocksInSeries( + "Encoder", layer_ipu, r"vit.encoder.layer.[0-9]+", log_insertions=log_insertions + ), + AddPoptorchBlock("LayerNorm Head Output", layer_ipu[-1], "vit.layernorm", log_insertions=log_insertions), + ] + if self.ipu_config.recompute_checkpoint_every_layer: + transformations += [ + RecomputationCheckpoint( + "vit.encoder.layer.[0-9]+", to_exclude=f"vit.encoder.layer.{self.config.num_layers - 1}" + ), + ] + return transformations + def parallelize(self): super().parallelize() - logger.info("---------- Device Allocation -----------") - logger.info("Embedding --> IPU 0") - self.vit.embeddings = poptorch.BeginBlock(self.vit.embeddings, "Embedding", ipu_id=0) + traced = symbolic_trace_pipelined_model(self) + transformations = self.get_transformations() + transformations += _OPTIMIZATION_TRANSFORMATIONS + composition = compose(*transformations) + non_reversible_composition = compose(*_NON_REVERSIBLE_TRANSFORMATIONS) + traced = composition(traced) + traced = non_reversible_composition(traced) + return traced - layer_ipu = 
get_layer_ipu(self.ipu_config.layers_per_ipu) - for index, layer in enumerate(self.vit.encoder.layer): - if self.ipu_config.recompute_checkpoint_every_layer: - # Put checkpoints on every encoder layer - h = recomputation_checkpoint(layer) - self._hooks.append(h) - ipu = layer_ipu[index] - logger.info(f"Encoder {index:<2} --> IPU {ipu}") - self.vit.encoder.layer[index] = poptorch.BeginBlock(layer, f"Encoder{index}", ipu_id=ipu) - - last_ipu = self.ipu_config.ipus_per_replica - 1 - logger.info(f"Head --> IPU {last_ipu}") - logger.info("---------------------------------------") - self.vit.layernorm = poptorch.BeginBlock(self.vit.layernorm, "LayerNorm", ipu_id=last_ipu) - self.classifier = poptorch.BeginBlock(self.classifier, "Classifier", ipu_id=last_ipu) + def deparallelize(self): + super().deparallelize() + transformations = self.get_transformations() + transformations += _OPTIMIZATION_TRANSFORMATIONS + composition = compose(*transformations) + self = composition(self, reverse=True) return self diff --git a/optimum/graphcore/trainer.py b/optimum/graphcore/trainer.py index b5399235e..2ba3bb47e 100644 --- a/optimum/graphcore/trainer.py +++ b/optimum/graphcore/trainer.py @@ -479,7 +479,8 @@ def compile_model( model.input_names_for_symbolic_trace = list(signature.parameters.keys())[: len(sample_batch)] else: model.input_names_for_symbolic_trace = list(sample_batch.keys()) - model = model.parallelize() + if not isinstance(model, torch.fx.GraphModule): + model = model.parallelize() if not self.args.fp32: model.half() if training: From b54758bd50182afcdac7d1435142843f0b1a8edf Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Thu, 11 Aug 2022 18:11:45 +0200 Subject: [PATCH 09/33] [WIP] Roberta --- .../graphcore/models/bart/modeling_bart.py | 48 ++-- .../graphcore/models/bert/modeling_bert.py | 6 - .../models/hubert/modeling_hubert.py | 10 +- .../models/roberta/modeling_roberta.py | 206 +++++++++--------- optimum/graphcore/models/t5/modeling_t5.py | 117 ++-------- optimum/graphcore/models/vit/modeling_vit.py | 8 +- 6 files changed, 154 insertions(+), 241 deletions(-) diff --git a/optimum/graphcore/models/bart/modeling_bart.py b/optimum/graphcore/models/bart/modeling_bart.py index 524a4db0d..bf64f7407 100644 --- a/optimum/graphcore/models/bart/modeling_bart.py +++ b/optimum/graphcore/models/bart/modeling_bart.py @@ -14,22 +14,13 @@ from typing import Optional, Tuple import torch -import torch.nn as nn import transformers +from optimum.utils import logging from transformers import BartForConditionalGeneration from transformers.models.bart.modeling_bart import BartAttention -from optimum.utils import logging from ....fx.optimization import ChangeTrueDivToMulByInverse, MergeLinears, ReversibleTransformation, compose -from ...generation_utils import IPUGenerationMixin -from ...modeling_utils import ( - GenerationMethodsMixin, - PipelineMixin, - get_layer_ipu, - register, -) -from ...fx.utils import symbolic_trace_pipelined_model from ...fx.transformations import ( AddPoptorchBlock, AddPoptorchBlocksInSeries, @@ -41,6 +32,9 @@ TieWeights, TupleOutput, ) +from ...fx.utils import symbolic_trace_pipelined_model +from ...generation_utils import IPUGenerationMixin +from ...modeling_utils import GenerationMethodsMixin, PipelineMixin, get_layer_ipu, register logger = logging.get_logger(__name__) @@ -163,7 +157,7 @@ def forward( attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - attn_weights = 
nn.functional.softmax(attn_weights, dim=-1) + attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1) if layer_head_mask is not None: if layer_head_mask.size() != (self.num_heads,): @@ -183,7 +177,7 @@ def forward( else: attn_weights_reshaped = None - attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + attn_probs = torch.nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) attn_output = torch.bmm(attn_probs, value_states) @@ -218,11 +212,14 @@ def get_transformations(self): # AddPoptorchBlock("Embedding", 0, "model.decoder.embed_positions"), # AddPoptorchBlock("Embedding", 0, "model.decoder.layernorm_embedding"), AddPoptorchBlocksInSeries( - "Encoder", layer_ipu[: self.config.encoder_layers], r"model.encoder.layers.[0-9]+", log_insertions=log_insertions + "Encoder", + layer_ipu[: self.config.encoder_layers], + r"model.encoder.layers.[0-9]+", + log_insertions=log_insertions, ), AddPoptorchBlocksInSeries( "Decoder", - layer_ipu[self.config.encoder_layers:], + layer_ipu[self.config.encoder_layers :], r"model.decoder.layers.[0-9]+", log_insertions=log_insertions, ), @@ -260,19 +257,16 @@ def parallelize(self): ``` """ super().parallelize() - if not isinstance(self, torch.fx.GraphModule): - orig_make_causal_mask = transformers.models.bart.modeling_bart._make_causal_mask - orig_expand_mask = transformers.models.bart.modeling_bart._expand_mask - transformers.models.bart.modeling_bart._make_causal_mask = _make_causal_mask - transformers.models.bart.modeling_bart._expand_mask = _expand_mask - for mod in self.modules(): - if isinstance(mod, BartAttention): - mod.__class__ = _BartAttentionWithoutException - traced = symbolic_trace_pipelined_model(self) - transformers.models.bart.modeling_bart._make_causal_mask = orig_make_causal_mask - transformers.models.bart.modeling_bart._expand_mask = orig_expand_mask - else: - traced = self + orig_make_causal_mask = transformers.models.bart.modeling_bart._make_causal_mask + orig_expand_mask = transformers.models.bart.modeling_bart._expand_mask + transformers.models.bart.modeling_bart._make_causal_mask = _make_causal_mask + transformers.models.bart.modeling_bart._expand_mask = _expand_mask + for mod in self.modules(): + if isinstance(mod, BartAttention): + mod.__class__ = _BartAttentionWithoutException + traced = symbolic_trace_pipelined_model(self) + transformers.models.bart.modeling_bart._make_causal_mask = orig_make_causal_mask + transformers.models.bart.modeling_bart._expand_mask = orig_expand_mask transformations = self.get_transformations() transformations += _OPTIMIZATION_TRANSFORMATIONS composition = compose(*transformations) diff --git a/optimum/graphcore/models/bert/modeling_bert.py b/optimum/graphcore/models/bert/modeling_bert.py index ebb0d89e5..b5760d98b 100644 --- a/optimum/graphcore/models/bert/modeling_bert.py +++ b/optimum/graphcore/models/bert/modeling_bert.py @@ -445,8 +445,6 @@ class PipelinedBertForSequenceClassification(BertForSequenceClassification, Bert ``` """ - pass - @register(BertForMultipleChoice) class PipelinedBertForMultipleChoice(BertForMultipleChoice, BertPipelineMixin): @@ -459,8 +457,6 @@ class PipelinedBertForMultipleChoice(BertForMultipleChoice, BertPipelineMixin): ``` """ - pass - @register(BertForTokenClassification) class PipelinedBertForTokenClassification(BertForTokenClassification, BertPipelineMixin): @@ -473,8 +469,6 @@ class PipelinedBertForTokenClassification(BertForTokenClassification, BertPipeli ``` """ - pass - 
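For readers following the `@register(...)` decorators used throughout these pipelined model files, here is a minimal sketch of the kind of class registry such a decorator could maintain; the real implementation lives in `modeling_utils` and may differ, so the dictionary name and structure below are assumptions.

```
from typing import Dict, Type

# Assumed registry shape: original transformers class -> pipelined counterpart.
_PIPELINED_MODELS: Dict[Type, Type] = {}


def register(transformers_cls: Type):
    # Decorator mapping a transformers model class to its pipelined wrapper,
    # so the trainer can later look up the IPU-ready class for a given model.
    def wrapper(pipelined_cls: Type) -> Type:
        _PIPELINED_MODELS[transformers_cls] = pipelined_cls
        return pipelined_cls

    return wrapper
```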
@register(BertForQuestionAnswering) class PipelinedBertForQuestionAnswering(BertForQuestionAnswering, BertPipelineMixin): diff --git a/optimum/graphcore/models/hubert/modeling_hubert.py b/optimum/graphcore/models/hubert/modeling_hubert.py index b274d9ead..577dad313 100644 --- a/optimum/graphcore/models/hubert/modeling_hubert.py +++ b/optimum/graphcore/models/hubert/modeling_hubert.py @@ -12,14 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. import torch + import poptorch from transformers import HubertForSequenceClassification from transformers.models.hubert.modeling_hubert import HubertEncoder, HubertEncoderStableLayerNorm -from ....utils import logging from ....fx.optimization import ChangeTrueDivToMulByInverse, MergeLinears, compose -from ...modeling_utils import PipelineMixin, get_layer_ipu, register -from ...fx.utils import symbolic_trace_pipelined_model +from ....utils import logging from ...fx.transformations import ( AddPoptorchBlock, AddPoptorchBlocksInSeries, @@ -28,6 +27,8 @@ RecomputationCheckpoint, TupleOutput, ) +from ...fx.utils import symbolic_trace_pipelined_model +from ...modeling_utils import PipelineMixin, get_layer_ipu, register logger = logging.get_logger(__name__) @@ -63,7 +64,8 @@ def get_transformations(self): if self.ipu_config.recompute_checkpoint_every_layer: transformations += [ RecomputationCheckpoint( - "hubert.encoder.layers.[0-9]+", to_exclude=f"hubert.encoder.layers.{self.config.num_layers - 1}" + "hubert.encoder.layers.[0-9]+", + to_exclude=f"hubert.encoder.layers.{self.config.num_hidden_layers - 1}", ), ] return transformations diff --git a/optimum/graphcore/models/roberta/modeling_roberta.py b/optimum/graphcore/models/roberta/modeling_roberta.py index ba721bb77..86db73511 100644 --- a/optimum/graphcore/models/roberta/modeling_roberta.py +++ b/optimum/graphcore/models/roberta/modeling_roberta.py @@ -29,6 +29,20 @@ ) from transformers.modeling_outputs import MaskedLMOutput, QuestionAnsweringModelOutput +from ....fx.optimization import ChangeTrueDivToMulByInverse, MergeLinears, compose +from ...fx.transformations import ( + AddPoptorchBlock, + AddPoptorchBlocksInSeries, + ClipValues, + ClipValuesSymmetric, + LinearToSerializedLinear, + OutlineAttribute, + RecomputationCheckpoint, + TieWeights, + TupleOutput, + VocabEmbeddingToSerializedEmbedding, +) +from ...fx.utils import symbolic_trace_pipelined_model from ...modeling_utils import ( OnehotGather, PipelineMixin, @@ -43,8 +57,45 @@ logger = logging.get_logger(__name__) +_OPTIMIZATION_TRANSFORMATIONS = [ + ChangeTrueDivToMulByInverse(), + MergeLinears(), + # FuseBiasInLinear(), +] + +_NON_REVERSIBLE_TRANSFORMATIONS = [ + ClipValuesSymmetric(1e4, exclude_targets=["view"]), + ClipValues(1e-4, float("inf"), include_targets=[torch.nn.LayerNorm]), + TupleOutput(), +] + class RobertaPipelineMixin(PipelineMixin): + def get_transformations(self): + log_insertions = self.ipu_config.log_insertions + layer_ipu = get_layer_ipu(self.ipu_config.layers_per_ipu) + last_ipu = len(self.ipu_config.layers_per_ipu) - 1 + transformations = [ + AddPoptorchBlock("Embedding", 0, "roberta.embeddings", log_insertions=log_insertions), + OutlineAttribute("roberta.embeddings.LayerNorm", "Embedding"), + AddPoptorchBlocksInSeries( + "Encoder", layer_ipu, r"roberta.encoder.layer.[0-9]+", log_insertions=log_insertions + ), + # Only one of the following AddPoptorchBlock, will actually add a block. 
+ AddPoptorchBlock("Classifier Output", last_ipu, "classifier", log_insertions=log_insertions), + AddPoptorchBlock("QA Outputs", last_ipu, "qa_outputs", log_insertions=log_insertions), + ] + if self.ipu_config.recompute_checkpoint_every_layer: + transformations.append( + RecomputationCheckpoint( + "roberta.encoder.layer.[0-9]+", + to_exclude=f"roberta.encoder.layer.{self.config.num_hidden_layers - 1}", + ) + ) + if self.ipu_config.embedding_serialization_factor > 1: + transformations.append(VocabEmbeddingToSerializedEmbedding()) + return transformations + def parallelize(self): """ Transform the Roberta model body to run in an IPU pipeline. @@ -53,37 +104,26 @@ def parallelize(self): - Adds recomputation checkpoints """ super().parallelize() - layer_ipu = get_layer_ipu(self.ipu_config.layers_per_ipu) - - logger.info("-------------------- Device Allocation --------------------") - logger.info("Embedding --> IPU 0") - if self.ipu_config.embedding_serialization_factor > 1: - self.roberta.embeddings.word_embeddings = SerializedEmbedding( - self.roberta.embeddings.word_embeddings, self.ipu_config.embedding_serialization_factor - ) - self.roberta.embeddings = poptorch.BeginBlock(self.roberta.embeddings, "Embedding", ipu_id=0) - hs = outline_attribute(self.roberta.embeddings.LayerNorm, "embedding") - self._hooks.extend(hs) - - for index, layer in enumerate(self.roberta.encoder.layer): - ipu = layer_ipu[index] - if self.ipu_config.recompute_checkpoint_every_layer and index != self.config.num_hidden_layers - 1: - h = recomputation_checkpoint(layer) - self._hooks.append(h) - self.roberta.encoder.layer[index] = poptorch.BeginBlock(layer, f"Encoder{index}", ipu_id=ipu) - logger.info(f"Encoder {index:<2} --> IPU {ipu}") - return self + traced = symbolic_trace_pipelined_model(self) + transformations = self.get_transformations() + transformations += _OPTIMIZATION_TRANSFORMATIONS + composition = compose(*transformations) + non_reversible_composition = compose(*_NON_REVERSIBLE_TRANSFORMATIONS) + traced = composition(traced) + traced = non_reversible_composition(traced) + return traced def deparallelize(self): """ Undo the changes to the model done by `parallelize`. You should call this before doing `save_pretrained` so that the `model.state_dict` is - fully compatible with `transformers.RobertaForSequenceClassification`. + fully compatible with the original model. 
""" super().deparallelize() - # Deserialize the serialized word embedding - if self.ipu_config.embedding_serialization_factor > 1: - self.roberta.embeddings.word_embeddings = self.roberta.embeddings.word_embeddings.deserialize() + transformations = self.get_transformations() + transformations += _OPTIMIZATION_TRANSFORMATIONS + composition = compose(*transformations) + self = composition(self, reverse=True) return self @@ -102,6 +142,37 @@ def __init__(self, config): super().__init__(config) self.gather_indices = OnehotGather() + # def get_ops_to_wrap_for_tracing(self): + # return [ + # ("torch.topk", *_gen_constructor_wrapper(torch.topk)), + # ("torch.nn.functional.one_hot", *_gen_constructor_wrapper(torch.nn.functional.one_hot)), + # ] + + def get_transformations(self): + log_insertions = self.ipu_config.log_insertions + layer_ipu = get_layer_ipu(self.ipu_config.layers_per_ipu) + transformations = [ + AddPoptorchBlock("Embedding", 0, "roberta.embeddings", log_insertions=log_insertions), + OutlineAttribute("roberta.embeddings.LayerNorm", "Embedding"), + AddPoptorchBlocksInSeries( + "Encoder", layer_ipu, r"roberta.encoder.layer.[0-9]+", log_insertions=log_insertions + ), + AddPoptorchBlock("LM Head", 0, "lm_head", log_insertions=log_insertions), + ] + if self.ipu_config.recompute_checkpoint_every_layer: + transformations.append( + RecomputationCheckpoint( + "roberta.encoder.layer.[0-9]+", + to_exclude=f"roberta.encoder.layer.{self.config.num_hidden_layers - 1}", + ) + ) + if self.ipu_config.embedding_serialization_factor > 1: + transformations += [ + LinearToSerializedLinear("lm_head.decoder"), + TieWeights("roberta.embeddings.word_embeddings", "lm_head.decoder"), + ] + return transformations + def parallelize(self): """ Transform the model to run in an IPU pipeline. 
@@ -110,39 +181,14 @@ def parallelize(self): - Adds recomputation checkpoints """ super().parallelize() - - if self.ipu_config.embedding_serialization_factor > 1: - serialized_decoder = SerializedLinear( - self.config.hidden_size, - self.config.vocab_size, - self.ipu_config.embedding_serialization_factor, - bias=True, - mode=poptorch.MatMulSerializationMode.OutputChannels, - ) - serialized_decoder.load_state_dict(self.lm_head.decoder.state_dict()) - self.lm_head.decoder = serialized_decoder - self.tie_weights() - - layer_ipu = get_layer_ipu(self.ipu_config.layers_per_ipu) - - logger.info("-------------------- Device Allocation --------------------") - logger.info("Embedding --> IPU 0") - self.roberta.embeddings = poptorch.BeginBlock(self.roberta.embeddings, "Embedding", ipu_id=0) - hs = outline_attribute(self.roberta.embeddings.LayerNorm, "embedding") - self._hooks.extend(hs) - - for index, layer in enumerate(self.roberta.encoder.layer): - ipu = layer_ipu[index] - if self.ipu_config.recompute_checkpoint_every_layer: - h = recomputation_checkpoint(layer) - self._hooks.append(h) - self.roberta.encoder.layer[index] = poptorch.BeginBlock(layer, f"Encoder{index}", ipu_id=ipu) - logger.info(f"Encoder {index:<2} --> IPU {ipu}") - - logger.info("LM Head --> IPU 0") - self.lm_head = poptorch.BeginBlock(self.lm_head, "LM Head", ipu_id=0) - logger.info("-----------------------------------------------------------") - return self + traced = symbolic_trace_pipelined_model(self) + transformations = self.get_transformations() + transformations += _OPTIMIZATION_TRANSFORMATIONS + composition = compose(*transformations) + non_reversible_composition = compose(*_NON_REVERSIBLE_TRANSFORMATIONS) + traced = composition(traced) + traced = non_reversible_composition(traced) + return traced def deparallelize(self): """ @@ -151,16 +197,10 @@ def deparallelize(self): compatible with the original model. 
""" super().deparallelize() - - if self.ipu_config.embedding_serialization_factor > 1: - decoder = nn.Linear( - self.config.hidden_size, - self.config.vocab_size, - bias=True, - ) - decoder.load_state_dict(self.lm_head.decoder.state_dict()) - self.lm_head.decoder = decoder - self.tie_weights() + transformations = self.get_transformations() + transformations += _OPTIMIZATION_TRANSFORMATIONS + composition = compose(*transformations) + self = composition(self, reverse=True) return self def forward( @@ -247,14 +287,6 @@ class PipelinedRobertaForSequenceClassification(RobertaForSequenceClassification ``` """ - def parallelize(self): - super().parallelize() - last_ipu = self.ipu_config.ipus_per_replica - 1 - logger.info(f"Classifier Output --> IPU {last_ipu}") - self.classifier = poptorch.BeginBlock(self.classifier, "Classifier Output", ipu_id=last_ipu) - logger.info("-----------------------------------------------------------") - return self - @register(RobertaForMultipleChoice) class PipelinedRobertaForMultipleChoice(RobertaForMultipleChoice, RobertaPipelineMixin): @@ -267,14 +299,6 @@ class PipelinedRobertaForMultipleChoice(RobertaForMultipleChoice, RobertaPipelin ``` """ - def parallelize(self): - super().parallelize() - last_ipu = self.ipu_config.ipus_per_replica - 1 - logger.info(f"Classifier Output --> IPU {last_ipu}") - self.classifier = poptorch.BeginBlock(self.classifier, "Classifier Output", ipu_id=last_ipu) - logger.info("-----------------------------------------------------------") - return self - @register(RobertaForTokenClassification) class PipelinedRobertaForTokenClassification(RobertaForTokenClassification, RobertaPipelineMixin): @@ -287,14 +311,6 @@ class PipelinedRobertaForTokenClassification(RobertaForTokenClassification, Robe ``` """ - def parallelize(self): - super().parallelize() - last_ipu = self.ipu_config.ipus_per_replica - 1 - logger.info(f"Classifier Output --> IPU {last_ipu}") - self.classifier = poptorch.BeginBlock(self.classifier, "Classifier Output", ipu_id=last_ipu) - logger.info("-----------------------------------------------------------") - return self - @register(RobertaForQuestionAnswering) class PipelinedRobertaForQuestionAnswering(RobertaForQuestionAnswering, RobertaPipelineMixin): @@ -307,14 +323,6 @@ class PipelinedRobertaForQuestionAnswering(RobertaForQuestionAnswering, RobertaP ``` """ - def parallelize(self): - super().parallelize() - last_ipu = self.ipu_config.ipus_per_replica - 1 - logger.info(f"QA Outputs --> IPU {last_ipu}") - self.qa_outputs = poptorch.BeginBlock(self.qa_outputs, "QA Outputs", ipu_id=last_ipu) - logger.info("-----------------------------------------------------------") - return self - def forward( self, input_ids: Optional[torch.LongTensor] = None, diff --git a/optimum/graphcore/models/t5/modeling_t5.py b/optimum/graphcore/models/t5/modeling_t5.py index 646714a62..d6bb08974 100644 --- a/optimum/graphcore/models/t5/modeling_t5.py +++ b/optimum/graphcore/models/t5/modeling_t5.py @@ -22,34 +22,23 @@ from optimum.utils import logging from transformers import T5ForConditionalGeneration from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput -from transformers.models.t5.modeling_t5 import __HEAD_MASK_WARNING_MSG, T5Block, T5Stack +from transformers.models.t5.modeling_t5 import __HEAD_MASK_WARNING_MSG, T5LayerNorm -from ....fx.optimization import ChangeTrueDivToMulByInverse, MergeLinears, compose +from ....fx.optimization import ChangeTrueDivToMulByInverse, MergeLinears, ReversibleTransformation, compose from 
...fx.transformations import ( AddPoptorchBlock, AddPoptorchBlocksInSeries, ClipValues, ClipValuesSymmetric, LinearToSerializedLinear, - OutlineAttribute, RecomputationCheckpoint, ShareEmbeddingComputation, TieWeights, TupleOutput, - VocabEmbeddingToSerializedEmbedding, ) from ...fx.utils import symbolic_trace_pipelined_model from ...generation_utils import IPUGenerationMixin -from ...modeling_utils import ( - GenerationMethodsMixin, - OnehotGather, - PipelineMixin, - SerializedLinear, - SharedEmbedding, - get_layer_ipu, - recomputation_checkpoint, - register, -) +from ...modeling_utils import GenerationMethodsMixin, PipelineMixin, SharedEmbedding, get_layer_ipu, register logger = logging.get_logger(__name__) @@ -158,12 +147,13 @@ def get_transformations(self): ), ] - if self.ipu_config.embedding_serialization_factor > 1: - transformations += [ - LinearToSerializedLinear("lm_head"), - TieWeights("shared", "lm_head"), - ] - transformations += [ShareEmbeddingComputation()] + if not isinstance(self, torch.fx.GraphModule): + if self.ipu_config.embedding_serialization_factor > 1: + transformations += [ + LinearToSerializedLinear("lm_head"), + TieWeights("shared", "lm_head"), + ] + transformations += [ShareEmbeddingComputation()] return transformations def parallelize(self): @@ -184,68 +174,12 @@ def parallelize(self): mod.forward = poptorch.autocast(enabled=True)(mod.forward) traced = symbolic_trace_pipelined_model(self) transformations = self.get_transformations() - # transformations += _OPTIMIZATION_TRANSFORMATIONS + transformations += _OPTIMIZATION_TRANSFORMATIONS composition = compose(*transformations) non_reversible_composition = compose(*_NON_REVERSIBLE_TRANSFORMATIONS) traced = composition(traced) traced = non_reversible_composition(traced) - import ipdb; ipdb.set_trace() return traced - # layer_ipu = get_layer_ipu(self.ipu_config.layers_per_ipu) - - # logger.info("-------------------- Device Allocation --------------------") - # logger.info("Embedding --> IPU 0") - - # if self.ipu_config.embedding_serialization_factor > 1: - # serialized_lm_head = SerializedLinear( - # self.config.d_model, - # self.shared.num_embeddings, - # self.ipu_config.embedding_serialization_factor, - # bias=False, - # mode=poptorch.MatMulSerializationMode.OutputChannels, - # ) - # serialized_lm_head.load_state_dict(self.lm_head.state_dict()) - # self.lm_head = serialized_lm_head - # # TODO: is it needed to check? 
- # if self.config.tie_word_embeddings: - # self.tie_weights() - - # self.scale_down_weights(factor=1) - # self.encoder_and_decoder_embeddings_computation(True) - # self.shared = poptorch.BeginBlock(self.shared, "Embedding", ipu_id=0) - - # Use a custom T5Stack implementation because sharing the position bias causes OOM error - # self.encoder.__class__ = CustomT5Stack - # self.decoder.__class__ = CustomT5Stack - - - # for index, layer in enumerate(self.encoder.block): - # ipu = layer_ipu[index] - # if self.ipu_config.recompute_checkpoint_every_layer and index != self.config.num_layers - 1: - # recomputation_checkpoint(layer) - # self.encoder.block[index] = poptorch.BeginBlock(layer, f"Encoder{index}", ipu_id=ipu) - # logger.info(f"Encoder {index:<2} --> IPU {ipu}") - - # self.encoder.final_layer_norm = poptorch.BeginBlock( - # self.encoder.final_layer_norm, "Encoder Stack Final LayerNorm", ipu_id=ipu - # ) - - # shift = len(self.encoder.block) - # for index, layer in enumerate(self.decoder.block): - # ipu = layer_ipu[index + shift] - # if self.ipu_config.recompute_checkpoint_every_layer and index != self.config.num_layers - 1: - # recomputation_checkpoint(layer) - # self.decoder.block[index] = poptorch.BeginBlock(layer, f"Decoder{index}", ipu_id=ipu) - # logger.info(f"Decoder {index:<2} --> IPU {ipu}") - - # self.decoder.final_layer_norm = poptorch.BeginBlock( - # self.decoder.final_layer_norm, "Decoder Stack Final LayerNorm", ipu_id=ipu - # ) - - # logger.info("LM Head Output --> IPU 0") - # self.lm_head = poptorch.BeginBlock(self.lm_head, "LM Head Output", ipu_id=0) - # logger.info("-----------------------------------------------------------") - return self def deparallelize(self): """ @@ -255,30 +189,11 @@ def deparallelize(self): """ # T5ForConditionalGeneration has a deparallelize method, so make sure that the PipelineMixin one is used here. PipelineMixin.deparallelize(self) - - self.encoder_and_decoder_embeddings_computation(False) - # self.scale_down_weights(factor=1, restore=True) - - self.encoder.__class__ = T5Stack - self.decoder.__class__ = T5Stack - - for block in self.encoder.block: - block.__class__ = T5Block - for block in self.decoder.block: - block.__class__ = T5Block - - if self.ipu_config.embedding_serialization_factor > 1: - old_lm_head = nn.Linear( - self.config.d_model, - self.shared.num_embeddings, - bias=False, - ) - old_lm_head.load_state_dict(self.lm_head.state_dict()) - self.lm_head = old_lm_head - # TODO: is it needed to check? 
- if self.config.tie_word_embeddings: - self.tie_weights() - + transformations = self.get_transformations() + transformations += _OPTIMIZATION_TRANSFORMATIONS + transformations = [t for t in transformations if isinstance(t, ReversibleTransformation)] + composition = compose(*transformations) + self = composition(self, reverse=True) return self def forward( diff --git a/optimum/graphcore/models/vit/modeling_vit.py b/optimum/graphcore/models/vit/modeling_vit.py index 281a5b8ac..ef36046db 100644 --- a/optimum/graphcore/models/vit/modeling_vit.py +++ b/optimum/graphcore/models/vit/modeling_vit.py @@ -15,10 +15,8 @@ import transformers -from ....utils import logging from ....fx.optimization import ChangeTrueDivToMulByInverse, MergeLinears, compose -from ...modeling_utils import PipelineMixin, get_layer_ipu, register -from ...fx.utils import symbolic_trace_pipelined_model +from ....utils import logging from ...fx.transformations import ( AddPoptorchBlock, AddPoptorchBlocksInSeries, @@ -27,6 +25,8 @@ RecomputationCheckpoint, TupleOutput, ) +from ...fx.utils import symbolic_trace_pipelined_model +from ...modeling_utils import PipelineMixin, get_layer_ipu, register logger = logging.get_logger(__name__) @@ -59,7 +59,7 @@ def get_transformations(self): if self.ipu_config.recompute_checkpoint_every_layer: transformations += [ RecomputationCheckpoint( - "vit.encoder.layer.[0-9]+", to_exclude=f"vit.encoder.layer.{self.config.num_layers - 1}" + "vit.encoder.layer.[0-9]+", to_exclude=f"vit.encoder.layer.{self.config.num_hidden_layers - 1}" ), ] return transformations From 3207fecf82b3f167866e65a6aef3646931975c1b Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Mon, 22 Aug 2022 12:04:34 +0200 Subject: [PATCH 10/33] Fix BART after rebasing --- .../graphcore/models/bart/modeling_bart.py | 125 ++++++++++-------- 1 file changed, 71 insertions(+), 54 deletions(-) diff --git a/optimum/graphcore/models/bart/modeling_bart.py b/optimum/graphcore/models/bart/modeling_bart.py index bf64f7407..a906cc376 100644 --- a/optimum/graphcore/models/bart/modeling_bart.py +++ b/optimum/graphcore/models/bart/modeling_bart.py @@ -14,10 +14,11 @@ from typing import Optional, Tuple import torch +from torch import nn import transformers from optimum.utils import logging -from transformers import BartForConditionalGeneration +from transformers import BartForConditionalGeneration, BartForSequenceClassification from transformers.models.bart.modeling_bart import BartAttention from ....fx.optimization import ChangeTrueDivToMulByInverse, MergeLinears, ReversibleTransformation, compose @@ -346,77 +347,93 @@ def forward( @register(BartForSequenceClassification) class PipelinedBartForSequenceClassification(BartForSequenceClassification, PipelineMixin): + def get_transformations(self): + log_insertions = self.ipu_config.log_insertions + layer_ipu = get_layer_ipu(self.ipu_config.layers_per_ipu) + transformations = [ + AddPoptorchBlock("Embedding", 0, "model.shared", log_insertions=log_insertions), + # AddPoptorchBlock("Embedding", 0, "model.encoder.embed_positions"), + # AddPoptorchBlock("Embedding", 0, "model.encoder.layernorm_embedding"), + # AddPoptorchBlock("Embedding", 0, "model.decoder.embed_positions"), + # AddPoptorchBlock("Embedding", 0, "model.decoder.layernorm_embedding"), + AddPoptorchBlocksInSeries( + "Encoder", + layer_ipu[: self.config.encoder_layers], + r"model.encoder.layers.[0-9]+", + log_insertions=log_insertions, + ), + AddPoptorchBlocksInSeries( + "Decoder", + layer_ipu[self.config.encoder_layers :], + 
r"model.decoder.layers.[0-9]+", + log_insertions=log_insertions, + ), + AddPoptorchBlock( + "Classification Head Output", layer_ipu[-1], "classification_head", log_insertions=log_insertions + ), + ] + if self.ipu_config.recompute_checkpoint_every_layer: + transformations += [ + RecomputationCheckpoint( + "model.encoder.layers.[0-9]+", to_exclude=f"model.encoder.layers.{self.config.encoder_layers - 1}" + ), + RecomputationCheckpoint( + "model.decoder.layers.[0-9]+", to_exclude=f"model.decoder.layers.{self.config.decoder_layers - 1}" + ), + ] + + if not isinstance(self, torch.fx.GraphModule): + if self.ipu_config.embedding_serialization_factor > 1: + transformations += [ + LinearToSerializedLinear("lm_head"), + TieWeights("model.shared", "lm_head"), + ] + transformations += [ShareEmbeddingComputation()] + return transformations + def parallelize(self): """ Transform the model to run in an IPU pipeline. - Adds pipeline stages to the model + - (If enabled) Replaces the shared embedding with a SerializedEmbedding - Adds recomputation checkpoints Recommended usage: ``` - model = PipelinedBartForSequenceClassification(config).parallelize().half() + model = PipelinedBartForConditionalGeneration(config).parallelize().half() ``` """ super().parallelize() - - self.model.__class__ = _BartModelWithSharedEmbedding - self.model.encoder_and_decoder_embeddings_computation(True) - self.model.change_bart_encoder_and_decoder_classes(False) - self.model.change_bart_attention_class(False) - - logger.info("-------------------- Device Allocation --------------------") - logger.info("Embedding --> IPU 0") - self.model.shared = poptorch.BeginBlock(self.model.shared, "Embedding", ipu_id=0) - self.model.encoder.embed_positions = poptorch.BeginBlock( - self.model.encoder.embed_positions, "Embedding", ipu_id=0 - ) - self.model.encoder.layernorm_embedding = poptorch.BeginBlock( - self.model.encoder.layernorm_embedding, "Embedding", ipu_id=0 - ) - - layer_ipu = get_layer_ipu(self.ipu_config.layers_per_ipu) - for index, layer in enumerate(self.model.encoder.layers): - ipu = layer_ipu[index] - if self.ipu_config.recompute_checkpoint_every_layer and index != self.config.num_hidden_layers - 1: - recomputation_checkpoint(layer) - self.model.encoder.layers[index] = poptorch.BeginBlock(layer, f"Encoder{index}", ipu_id=ipu) - logger.info(f"Encoder {index:<2} --> IPU {ipu}") - - self.model.decoder.embed_positions = poptorch.BeginBlock( - self.model.decoder.embed_positions, "Embedding", ipu_id=0 - ) - self.model.decoder.layernorm_embedding = poptorch.BeginBlock( - self.model.decoder.layernorm_embedding, "Embedding", ipu_id=0 - ) - shift = len(self.model.encoder.layers) - for index, layer in enumerate(self.model.decoder.layers): - ipu = layer_ipu[index + shift] - if self.ipu_config.recompute_checkpoint_every_layer and index != self.config.num_hidden_layers - 1: - recomputation_checkpoint(layer) - self.model.decoder.layers[index] = poptorch.BeginBlock(layer, f"Decoder{index}", ipu_id=ipu) - logger.info(f"Decoder {index:<2} --> IPU {ipu}") - - last_ipu = len(self.ipu_config.layers_per_ipu) - 1 - logger.info(f"Classification Head Output --> IPU {last_ipu}") - self.classification_head = poptorch.BeginBlock( - self.classification_head, "Classification Head Output", ipu_id=last_ipu - ) - logger.info("-----------------------------------------------------------") - return self + orig_make_causal_mask = transformers.models.bart.modeling_bart._make_causal_mask + orig_expand_mask = transformers.models.bart.modeling_bart._expand_mask + 
transformers.models.bart.modeling_bart._make_causal_mask = _make_causal_mask + transformers.models.bart.modeling_bart._expand_mask = _expand_mask + for mod in self.modules(): + if isinstance(mod, BartAttention): + mod.__class__ = _BartAttentionWithoutException + traced = symbolic_trace_pipelined_model(self) + transformers.models.bart.modeling_bart._make_causal_mask = orig_make_causal_mask + transformers.models.bart.modeling_bart._expand_mask = orig_expand_mask + transformations = self.get_transformations() + transformations += _OPTIMIZATION_TRANSFORMATIONS + composition = compose(*transformations) + non_reversible_composition = compose(*_NON_REVERSIBLE_TRANSFORMATIONS) + traced = composition(traced) + traced = non_reversible_composition(traced) + return traced def deparallelize(self): """ Undo the changes to the model done by `parallelize`. You should call this before doing `save_pretrained` so that the `model.state_dict` is - fully compatible with `transformers.BartForSequenceClassification`. + fully compatible with `transformers.BartForConditionalGeneration`. """ super().deparallelize() - - self.model.encoder_and_decoder_embeddings_computation(False) - self.model.change_bart_encoder_and_decoder_classes(True) - self.model.change_bart_attention_class(True) - self.model.__class__ = BartModel - + transformations = self.get_transformations() + transformations += _OPTIMIZATION_TRANSFORMATIONS + transformations = [t for t in transformations if isinstance(t, ReversibleTransformation)] + composition = compose(*transformations) + self = composition(self, reverse=True) return self def forward( From 2f4b9688223a928bb52c17b0e8ef7b0eddaf60f6 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Mon, 22 Aug 2022 14:41:47 +0200 Subject: [PATCH 11/33] Small fixes --- .../graphcore/models/hubert/modeling_hubert.py | 9 --------- .../models/roberta/modeling_roberta.py | 10 ---------- optimum/graphcore/models/t5/modeling_t5.py | 18 ------------------ 3 files changed, 37 deletions(-) diff --git a/optimum/graphcore/models/hubert/modeling_hubert.py b/optimum/graphcore/models/hubert/modeling_hubert.py index 577dad313..c337911ec 100644 --- a/optimum/graphcore/models/hubert/modeling_hubert.py +++ b/optimum/graphcore/models/hubert/modeling_hubert.py @@ -15,7 +15,6 @@ import poptorch from transformers import HubertForSequenceClassification -from transformers.models.hubert.modeling_hubert import HubertEncoder, HubertEncoderStableLayerNorm from ....fx.optimization import ChangeTrueDivToMulByInverse, MergeLinears, compose from ....utils import logging @@ -88,11 +87,3 @@ def deparallelize(self): composition = compose(*transformations) self = composition(self, reverse=True) return self - - def deparallelize(self): - """ - Undo the changes to the model done by `parallelize`. 
- """ - super().deparallelize() - self.change_hubert_encoder_class(True) - return self diff --git a/optimum/graphcore/models/roberta/modeling_roberta.py b/optimum/graphcore/models/roberta/modeling_roberta.py index 86db73511..24d88f765 100644 --- a/optimum/graphcore/models/roberta/modeling_roberta.py +++ b/optimum/graphcore/models/roberta/modeling_roberta.py @@ -46,11 +46,7 @@ from ...modeling_utils import ( OnehotGather, PipelineMixin, - SerializedEmbedding, - SerializedLinear, get_layer_ipu, - outline_attribute, - recomputation_checkpoint, register, ) @@ -142,12 +138,6 @@ def __init__(self, config): super().__init__(config) self.gather_indices = OnehotGather() - # def get_ops_to_wrap_for_tracing(self): - # return [ - # ("torch.topk", *_gen_constructor_wrapper(torch.topk)), - # ("torch.nn.functional.one_hot", *_gen_constructor_wrapper(torch.nn.functional.one_hot)), - # ] - def get_transformations(self): log_insertions = self.ipu_config.log_insertions layer_ipu = get_layer_ipu(self.ipu_config.layers_per_ipu) diff --git a/optimum/graphcore/models/t5/modeling_t5.py b/optimum/graphcore/models/t5/modeling_t5.py index d6bb08974..8f63048f9 100644 --- a/optimum/graphcore/models/t5/modeling_t5.py +++ b/optimum/graphcore/models/t5/modeling_t5.py @@ -64,24 +64,6 @@ class PipelinedT5ForConditionalGeneration( def is_encoder_and_decoder_embeddings_computation_shared(self): return isinstance(self.shared, SharedEmbedding) - def encoder_and_decoder_embeddings_computation(self, use_shared_embedding: bool): - """Sets the T5ForConditionalGeneration shared embedding layer to SharedEmbedding that combines the computation under one layer. - - Args: - use_shared_embedding: whether to use SharedEmbedding or not. - """ - - if use_shared_embedding: - if isinstance(self.shared, SharedEmbedding): - logger.warning("encoder and decoder embeddings computation is already shared") - else: - self.shared = SharedEmbedding(self.shared) - else: - if isinstance(self.shared, nn.Embedding): - logger.warning("encoder and decoder embeddings computation is not shared") - else: - self.shared = self.shared.shared - def scale_down_weights(self, factor: float = 1, restore: bool = False): self.lm_scale_modifier = 1 if not restore else None # self.lm_scale_modifier = nn.Parameter(torch.ones(self.config.d_model, dtype=torch.float16)) if not restore else None From 28ff013bcd605ad8c0862fab0a4e639623eac4cf Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Fri, 26 Aug 2022 18:35:33 +0200 Subject: [PATCH 12/33] [WIP] working or almost working version for everyone --- examples/text-classification/run_glue.py | 19 +- .../models/convnext/modeling_convnext.py | 138 +++++++-- .../models/deberta/modeling_deberta.py | 182 +++++------ .../graphcore/models/gpt2/modeling_gpt2.py | 288 +++++++++--------- .../models/lxmert/modeling_lxmert.py | 136 +++++---- .../models/roberta/modeling_roberta.py | 9 +- .../models/wav2vec2/modeling_wav2vec2.py | 175 ++++++----- optimum/graphcore/trainer.py | 10 +- 8 files changed, 517 insertions(+), 440 deletions(-) diff --git a/examples/text-classification/run_glue.py b/examples/text-classification/run_glue.py index afb3d63be..1905e7209 100755 --- a/examples/text-classification/run_glue.py +++ b/examples/text-classification/run_glue.py @@ -469,17 +469,14 @@ def preprocess_function(examples): max_train_samples = min(len(train_dataset), data_args.max_train_samples) train_dataset = train_dataset.select(range(max_train_samples)) - # labels = torch.tensor(train_dataset[0]["label"]) - # if model.config.problem_type is None: 
- # if model.config.num_labels == 1: - # model.config.problem_type = "regression" - # elif model.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): - # model.config.problem_type = "single_label_classification" - # else: - # model.config.problem_type = "multi_label_classification" - dummy_input = tokenizer("Used to set the model.config.problem_type", return_tensors="pt") - dummy_input["labels"] = torch.tensor(train_dataset[0]["label"]) - model(**dummy_input) + labels = torch.tensor(train_dataset[0]["label"]) + if model.config.problem_type is None: + if model.config.num_labels == 1: + model.config.problem_type = "regression" + elif model.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + model.config.problem_type = "single_label_classification" + else: + model.config.problem_type = "multi_label_classification" if training_args.do_eval: if "validation" not in raw_datasets and "validation_matched" not in raw_datasets: diff --git a/optimum/graphcore/models/convnext/modeling_convnext.py b/optimum/graphcore/models/convnext/modeling_convnext.py index 65c7478ae..d0a9983a2 100644 --- a/optimum/graphcore/models/convnext/modeling_convnext.py +++ b/optimum/graphcore/models/convnext/modeling_convnext.py @@ -11,17 +11,22 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import torch -import torch.nn as nn - import poptorch -from optimum.utils import logging -from transformers.models.convnext.modeling_convnext import ( - ConvNextForImageClassification, - ConvNextLayer, - ConvNextLayerNorm, +import transformers +from transformers.models.convnext.modeling_convnext import ConvNextLayer, ConvNextLayerNorm, ConvNextForImageClassification + +from ....fx.optimization import ChangeTrueDivToMulByInverse, MergeLinears, compose +from ....utils import logging +from ...fx.transformations import ( + AddPoptorchBlock, + AddPoptorchBlocksInSeries, + ClipValues, + ClipValuesSymmetric, + RecomputationCheckpoint, + TupleOutput, ) +from ...fx.utils import symbolic_trace_pipelined_model from ...modeling_utils import PipelineMixin, get_layer_ipu, register from .optimized_convnextlayer import OptimizedConvNextLayer @@ -29,7 +34,20 @@ logger = logging.get_logger(__name__) +_OPTIMIZATION_TRANSFORMATIONS = [ + ChangeTrueDivToMulByInverse(), + MergeLinears(), + # FuseBiasInLinear(), +] + +_NON_REVERSIBLE_TRANSFORMATIONS = [ + ClipValuesSymmetric(1e4, exclude_targets=["view"]), + ClipValues(1e-4, float("inf"), include_targets=[torch.nn.LayerNorm]), + TupleOutput(), +] + +<<<<<<< HEAD class IPUConvNextLayerNorm(nn.Module): r"""LayerNorm that supports two data formats: channels_last (default) or channels_first. The ordering of the dimensions in the inputs. 
channels_last corresponds to inputs with shape (batch_size, height, @@ -62,43 +80,60 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: @register(ConvNextForImageClassification) class PipelinedConvNextForImageClassification(ConvNextForImageClassification, PipelineMixin): + def get_transformations(self): + log_insertions = self.ipu_config.log_insertions + layer_ipu = get_layer_ipu(self.ipu_config.layers_per_ipu) + transformations = [ + AddPoptorchBlock( + "Embedding", 0, r"convnext.embeddings", log_insertions=log_insertions + ), + AddPoptorchBlocksInSeries( + "Encoder", layer_ipu, r"convnext.encoder.stages.[0-9]+.layers.[0-9]+", log_insertions=log_insertions + ), + AddPoptorchBlock( + "LayerNorm", layer_ipu[-1], r"convnext.layernorm", log_insertions=log_insertions + ), + AddPoptorchBlock( + "Classifier", layer_ipu[-1], r"classifier", log_insertions=log_insertions + ), + ] + if self.ipu_config.recompute_checkpoint_every_layer: + transformations += [ + RecomputationCheckpoint( + "convnext.encoder.stages.[0-9]+.layers.[0-9]+", + to_exclude=f"convnext.encoder.stages.{self.config.num_stages - 1}.layers.{self.config.depths[-1] - 1}", + ), + ] + return transformations + def parallelize(self): super().parallelize() - # Use optimized ConvNextLayer - for stage in self.convnext.encoder.stages: - for layer in stage.layers: - layer.__class__ = OptimizedConvNextLayer + if not isinstance(self, torch.fx.GraphModule): + # Use optimized ConvNextLayer + for stage in self.convnext.encoder.stages: + for layer in stage.layers: + layer.__class__ = OptimizedConvNextLayer +<<<<<<< HEAD # Enable autocast for ConvNextLayerNorm because computation cannot happen in fp16 for mod in self.modules(): if isinstance(mod, ConvNextLayerNorm): mod.__class__ = IPUConvNextLayerNorm - logger.info("---------- Device Allocation -----------") - logger.info(f"Embedding --> IPU 0") - self.convnext.embeddings = poptorch.BeginBlock(self.convnext.embeddings, "Embedding", ipu_id=0) - - layer_ipu = get_layer_ipu(self.ipu_config.layers_per_ipu) - global_layer_idx = 0 - for stage_idx, stage in enumerate(self.convnext.encoder.stages): - for layer_idx, layer in enumerate(stage.layers): - ipu = layer_ipu[global_layer_idx] - logger.info(f"Encoder stage {stage_idx}, convnext layer {layer_idx} --> IPU {ipu}") - layer = poptorch.BeginBlock(layer, f"Encoder_stage_{stage_idx}_layer_{layer_idx}", ipu_id=ipu) - global_layer_idx += 1 - - last_ipu = self.ipu_config.ipus_per_replica - 1 - logger.info(f"Head --> IPU {last_ipu}") - logger.info("---------------------------------------") - self.convnext.layernorm = poptorch.BeginBlock(self.convnext.layernorm, "LayerNorm", ipu_id=last_ipu) - self.classifier = poptorch.BeginBlock(self.classifier, "Classifier", ipu_id=last_ipu) - - return self + traced = symbolic_trace_pipelined_model(self) + transformations = self.get_transformations() + transformations += _OPTIMIZATION_TRANSFORMATIONS + composition = compose(*transformations) + non_reversible_composition = compose(*_NON_REVERSIBLE_TRANSFORMATIONS) + traced = composition(traced) + traced = non_reversible_composition(traced) + return traced def deparallelize(self): super().deparallelize() + # TODO: is that needed? 
for mod in self.modules(): if isinstance(mod, IPUConvNextLayerNorm): mod.__class__ = ConvNextLayerNorm @@ -107,3 +142,42 @@ def deparallelize(self): for stage in self.convnext.encoder.stages: for layer in stage.layers: layer.__class__ = ConvNextLayer + + transformations = self.get_transformations() + transformations += _OPTIMIZATION_TRANSFORMATIONS + composition = compose(*transformations) + self = composition(self, reverse=True) + return self + # def parallelize(self): + # super().parallelize() + + # # Use optimized ConvNextLayer + # for stage in self.convnext.encoder.stages: + # for layer in stage.layers: + # layer.__class__ = OptimizedConvNextLayer + + # # Enable autocast for ConvNextLayerNorm because computation cannot happen in fp16 + # for mod in self.modules(): + # if isinstance(mod, ConvNextLayerNorm): + # mod.forward = poptorch.autocast(enabled=True)(mod.forward) + + # logger.info("---------- Device Allocation -----------") + # logger.info(f"Embedding --> IPU 0") + # self.convnext.embeddings = poptorch.BeginBlock(self.convnext.embeddings, "Embedding", ipu_id=0) + + # layer_ipu = get_layer_ipu(self.ipu_config.layers_per_ipu) + # global_layer_idx = 0 + # for stage_idx, stage in enumerate(self.convnext.encoder.stages): + # for layer_idx, layer in enumerate(stage.layers): + # ipu = layer_ipu[global_layer_idx] + # logger.info(f"Encoder stage {stage_idx}, convnext layer {layer_idx} --> IPU {ipu}") + # layer = poptorch.BeginBlock(layer, f"Encoder_stage_{stage_idx}_layer_{layer_idx}", ipu_id=ipu) + # global_layer_idx += 1 + + # last_ipu = self.ipu_config.ipus_per_replica - 1 + # logger.info(f"Head --> IPU {last_ipu}") + # logger.info("---------------------------------------") + # self.convnext.layernorm = poptorch.BeginBlock(self.convnext.layernorm, "LayerNorm", ipu_id=last_ipu) + # self.classifier = poptorch.BeginBlock(self.classifier, "Classifier", ipu_id=last_ipu) + + # return self diff --git a/optimum/graphcore/models/deberta/modeling_deberta.py b/optimum/graphcore/models/deberta/modeling_deberta.py index c1dc1b0f8..412b0bb00 100644 --- a/optimum/graphcore/models/deberta/modeling_deberta.py +++ b/optimum/graphcore/models/deberta/modeling_deberta.py @@ -20,14 +20,7 @@ import torch.nn.functional as F import poptorch -from optimum.utils import logging -from transformers import ( - DebertaForMaskedLM, - DebertaForQuestionAnswering, - DebertaForSequenceClassification, - DebertaForTokenClassification, -) -from transformers.modeling_outputs import MaskedLMOutput, QuestionAnsweringModelOutput +from transformers import DebertaForQuestionAnswering, DebertaForSequenceClassification, DebertaForTokenClassification from transformers.models.deberta.modeling_deberta import ( DebertaEncoder, DisentangledSelfAttention, @@ -35,20 +28,36 @@ build_relative_position, ) -from ...modeling_utils import ( - OnehotGather, - PipelineMixin, - SerializedEmbedding, - SerializedLinear, - get_layer_ipu, - outline_attribute, - recomputation_checkpoint, - register, +from ....fx.optimization import ChangeTrueDivToMulByInverse, MergeLinears, compose +from ....utils import logging +from ...fx.transformations import ( + AddPoptorchBlock, + AddPoptorchBlocksInSeries, + ClipValues, + ClipValuesSymmetric, + OutlineAttribute, + RecomputationCheckpoint, + TupleOutput, + VocabEmbeddingToSerializedEmbedding, ) +from ...fx.utils import symbolic_trace_pipelined_model +from ...modeling_utils import PipelineMixin, get_layer_ipu, register logger = logging.get_logger(__name__) +_OPTIMIZATION_TRANSFORMATIONS = [ + # 
ChangeTrueDivToMulByInverse(), + MergeLinears(), + # FuseBiasInLinear(), +] + +_NON_REVERSIBLE_TRANSFORMATIONS = [ + ClipValuesSymmetric(1e4, exclude_targets=["view"]), + ClipValues(1e-4, float("inf"), include_targets=[torch.nn.LayerNorm]), + TupleOutput(), +] + class FastGatherLastDim(nn.Module): """ @@ -207,7 +216,7 @@ def linear(w, b, x): context_layer = torch.matmul(attention_probs, value_layer) context_layer = context_layer.permute(0, 2, 1, 3).contiguous() new_context_layer_shape = context_layer.size()[:-2] + (-1,) - context_layer = context_layer.view(*new_context_layer_shape) + context_layer = context_layer.view(new_context_layer_shape) if output_attentions: return (context_layer, attention_probs) else: @@ -238,7 +247,8 @@ def disentangled_att_bias(self, query_layer, key_layer, relative_pos, rel_embedd pos_key_layer = self.pos_proj(rel_embeddings) pos_key_layer = self.transpose_for_scores(pos_key_layer) c2p_att = torch.matmul(query_layer, pos_key_layer.transpose(-1, -2)) - c2p_pos = torch.clamp(relative_pos + att_span, 0, att_span * 2 - 1) + # c2p_pos = torch.clamp(relative_pos + att_span, 0, att_span * 2 - 1) + c2p_pos = (relative_pos + att_span).clamp(0, att_span * 2 - 1) index = c2p_pos.expand( [query_layer.size(0), query_layer.size(1), query_layer.size(2), relative_pos.size(-1)] ) @@ -292,58 +302,48 @@ def change_modules_for_ipu(self, restore: bool): func = DebertaEncoder.get_rel_embedding if restore else _get_rel_embedding mod.get_rel_embedding = func.__get__(mod, DebertaEncoder) + def get_transformations(self): + log_insertions = self.ipu_config.log_insertions + layer_ipu = get_layer_ipu(self.ipu_config.layers_per_ipu) + transformations = [ + AddPoptorchBlock("Embedding", 0, "deberta.embeddings", log_insertions=log_insertions), + OutlineAttribute("deberta.embeddings.LayerNorm", "Embedding"), + AddPoptorchBlock("Before Encoder", 0, "deberta.encoder", log_insertions=log_insertions), + AddPoptorchBlocksInSeries( + "Encoder", layer_ipu, r"deberta.encoder.layer.[0-9]+", log_insertions=log_insertions + ), + # Only one of the following AddPoptorchBlock, will actually add a block. + AddPoptorchBlock("Classifier Output", layer_ipu[-1], "classifier", log_insertions=log_insertions), + AddPoptorchBlock("QA Outputs", layer_ipu[-1], "qa_outputs", log_insertions=log_insertions), + ] + # if self.ipu_config.recompute_checkpoint_every_layer: + # transformations.append( + # RecomputationCheckpoint( + # "deberta.encoder.layer.[0-9]+", + # to_exclude=f"deberta.encoder.layer.{self.config.num_hidden_layers - 1}", + # ) + # ) + if self.ipu_config.embedding_serialization_factor > 1: + transformations.append(VocabEmbeddingToSerializedEmbedding()) + return transformations + def parallelize(self): """ Transform the model to run in an IPU pipeline. 
- Adds pipeline stages to the model - - (If enabled) Replaces the word embedding with a SerializedEmbedding - - Replaces several modules with IPU compatible counterparts + - (If enabled) Replaces the word embedding projection with a SerializedLinear layer - Adds recomputation checkpoints """ - self._hooks = [] - layer_ipu = get_layer_ipu(self.ipu_config.layers_per_ipu) - - logger.info("-------------------- Device Allocation --------------------") - logger.info("Embedding --> IPU 0") - if self.ipu_config.embedding_serialization_factor > 1: - if isinstance(self, PipelinedDebertaForMaskedLM): - serialized_decoder = SerializedLinear( - self.config.hidden_size, - self.config.vocab_size, - self.ipu_config.embedding_serialization_factor, - bias=True, - mode=poptorch.MatMulSerializationMode.OutputChannels, - ) - serialized_decoder.load_state_dict(self.cls.predictions.decoder.state_dict()) - self.cls.predictions.decoder = serialized_decoder - self.tie_weights() - else: - self.deberta.embeddings.word_embeddings = SerializedEmbedding( - self.deberta.embeddings.word_embeddings, self.ipu_config.embedding_serialization_factor - ) - + super().parallelize() self.change_modules_for_ipu(False) - - self.deberta.embeddings = poptorch.BeginBlock(self.deberta.embeddings, "Embedding", ipu_id=0) - hs = outline_attribute(self.deberta.embeddings.LayerNorm, "embedding") - self._hooks.extend(hs) - - self.deberta.encoder = poptorch.BeginBlock(self.deberta.encoder, ipu_id=0) - if self.deberta.encoder.relative_attention: - self.deberta.encoder.rel_embeddings = poptorch.BeginBlock(self.deberta.encoder.rel_embeddings, ipu_id=0) - - for index, layer in enumerate(self.deberta.encoder.layer): - ipu = layer_ipu[index] - if self.ipu_config.recompute_checkpoint_every_layer and index != self.config.num_hidden_layers - 1: - h = recomputation_checkpoint(layer) - self._hooks.append(h) - self.deberta.encoder.layer[index] = poptorch.BeginBlock(layer, f"Encoder{index}", ipu_id=ipu) - logger.info(f"Encoder {index:<2} --> IPU {ipu}") - - if isinstance(self, PipelinedDebertaForMaskedLM): - logger.info(f"Projection {index:<2} --> IPU {0}") - self.cls.predictions.decoder = poptorch.BeginBlock(self.cls.predictions.decoder, "Projection", ipu_id=0) - return self + traced = symbolic_trace_pipelined_model(self) + transformations = self.get_transformations() + transformations += _OPTIMIZATION_TRANSFORMATIONS + composition = compose(*transformations, inplace=True) + non_reversible_composition = compose(*_NON_REVERSIBLE_TRANSFORMATIONS) + traced = composition(traced) + traced = non_reversible_composition(traced) + return traced def deparallelize(self): """ @@ -353,19 +353,10 @@ def deparallelize(self): """ super().deparallelize() self.change_modules_for_ipu(True) - if self.ipu_config.embedding_serialization_factor > 1: - if isinstance(self, PipelinedDebertaForMaskedLM): - decoder = nn.Linear( - self.config.hidden_size, - self.config.vocab_size, - bias=True, - ) - decoder.load_state_dict(self.cls.predictions.decoder.state_dict()) - self.cls.predictions.decoder = decoder - self.tie_weights() - else: - # Deserialize the serialized word embedding - self.deberta.embeddings.word_embeddings = self.deberta.embeddings.word_embeddings.deserialize() + transformations = self.get_transformations() + transformations += _OPTIMIZATION_TRANSFORMATIONS + composition = compose(*transformations) + self = composition(self, reverse=True) return self @@ -457,14 +448,8 @@ class PipelinedDebertaForSequenceClassification(DebertaForSequenceClassification ``` """ - def 
parallelize(self): - super().parallelize() - last_ipu = self.ipu_config.ipus_per_replica - 1 - logger.info(f"Classifier Output --> IPU {last_ipu}") - self.classifier = poptorch.BeginBlock(self.classifier, "Classifier Output", ipu_id=last_ipu) - logger.info("-----------------------------------------------------------") - return self - + # def forward(self, input_ids, attention_mask, token_type_ids, labels=None): + # return_dict = False @register(DebertaForTokenClassification) class PipelinedDebertaForTokenClassification(DebertaForTokenClassification, DebertaPipelineMixin): @@ -477,24 +462,6 @@ class PipelinedDebertaForTokenClassification(DebertaForTokenClassification, Debe ``` """ - def parallelize(self): - super().parallelize() - last_ipu = self.ipu_config.ipus_per_replica - 1 - logger.info(f"Classifier Output --> IPU {last_ipu}") - self.classifier = poptorch.BeginBlock(self.classifier, "Classifier Output", ipu_id=last_ipu) - logger.info("-----------------------------------------------------------") - return self - - def deparallelize(self): - super().deparallelize() - # Last dropout isn't a StableDropout so undo its replacement - # made by change_modules_for_ipu - mod = self.dropout - if isinstance(mod, StableDropout): - mod.__class__ = nn.Dropout - mod.p = mod.drop_prob - mod.inplace = False - @register(DebertaForQuestionAnswering) class PipelinedDebertaForQuestionAnswering(DebertaForQuestionAnswering, DebertaPipelineMixin): @@ -507,14 +474,6 @@ class PipelinedDebertaForQuestionAnswering(DebertaForQuestionAnswering, DebertaP ``` """ - def parallelize(self): - super().parallelize() - last_ipu = self.ipu_config.ipus_per_replica - 1 - logger.info(f"QA Outputs --> IPU {last_ipu}") - self.qa_outputs = poptorch.BeginBlock(self.qa_outputs, "QA Outputs", ipu_id=last_ipu) - logger.info("-----------------------------------------------------------") - return self - def forward( self, input_ids: Optional[torch.Tensor] = None, @@ -539,7 +498,6 @@ def forward( are not taken into account for computing the loss. 
""" return_dict = return_dict if return_dict is not None else self.config.use_return_dict - # return_dict = False output = super().forward( input_ids, diff --git a/optimum/graphcore/models/gpt2/modeling_gpt2.py b/optimum/graphcore/models/gpt2/modeling_gpt2.py index a602909f6..ecd649ca7 100644 --- a/optimum/graphcore/models/gpt2/modeling_gpt2.py +++ b/optimum/graphcore/models/gpt2/modeling_gpt2.py @@ -19,180 +19,188 @@ import torch.nn as nn import poptorch -from optimum.utils import logging from transformers import GPT2ForSequenceClassification, GPT2ForTokenClassification, GPT2LMHeadModel from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions, SequenceClassifierOutputWithPast from transformers.models.gpt2.modeling_gpt2 import GPT2Attention -from ...modeling_utils import ( - PipelineMixin, - SerializedEmbedding, - SerializedLinear, - get_layer_ipu, - outline_attribute, - recomputation_checkpoint, - register, +from ....fx.optimization import ChangeTrueDivToMulByInverse, MergeLinears, ReversibleTransformation, compose +from ....utils import logging +from ...fx.transformations import ( + AddPoptorchBlock, + AddPoptorchBlocksInSeries, + ClipValues, + ClipValuesSymmetric, + LinearToSerializedLinear, + OutlineAttribute, + RecomputationCheckpoint, + TieWeights, + TupleOutput, + VocabEmbeddingToSerializedEmbedding, ) +from ...fx.utils import symbolic_trace_pipelined_model +from ...modeling_utils import PipelineMixin, get_layer_ipu, register from .optimized_gpt2_attn import OptimizedGPT2Attention logger = logging.get_logger(__name__) +_OPTIMIZATION_TRANSFORMATIONS = [ + ChangeTrueDivToMulByInverse(), + MergeLinears(), + # FuseBiasInLinear(), +] + +_NON_REVERSIBLE_TRANSFORMATIONS = [ + ClipValuesSymmetric(1e4, exclude_targets=["view"]), + ClipValues(1e-4, float("inf"), include_targets=[torch.nn.LayerNorm]), + TupleOutput(), +] + class GPT2PipelineMixin(PipelineMixin): + @property + def actual_vocab_size(self): + return self.config.vocab_size + + @property + def new_vocab_size(self): + new_vocab_size = ( + math.ceil(self.config.vocab_size / self.ipu_config.embedding_serialization_factor) + * self.ipu_config.embedding_serialization_factor + ) + return new_vocab_size + + def resize_vocab(self, restore: bool): + if restore: + # Resize token embeddings back to origianl vocab_size + if self.config.vocab_size > self.actual_vocab_size: + self.resize_token_embeddings(self.actual_vocab_size) + else: + if self.new_vocab_size > self.actual_vocab_size: + self.resize_token_embeddings(self.new_vocab_size) + + def get_transformations(self): + log_insertions = self.ipu_config.log_insertions + layer_ipu = get_layer_ipu(self.ipu_config.layers_per_ipu) + transformations = [ + AddPoptorchBlock("Token Embedding", 0, "transformer.wte", log_insertions=log_insertions), + AddPoptorchBlock("Position Embedding", 1, "transformer.wtp", log_insertions=log_insertions), + OutlineAttribute("transformer.ln_f", "LayerNorm"), + AddPoptorchBlocksInSeries("Layer", layer_ipu, r"transformer.h.[0-9]+", log_insertions=log_insertions), + # Only one of the following AddPoptorchBlock, will actually add a block. 
+ AddPoptorchBlock("Score", layer_ipu[-1], "score", log_insertions=log_insertions), + AddPoptorchBlock("Classifier", layer_ipu[-1], "classifier", log_insertions=log_insertions), + ] + if self.ipu_config.recompute_checkpoint_every_layer: + transformations.append( + RecomputationCheckpoint( + "transformer.h.[0-9]+", + to_exclude=f"transformer.h.{self.config.num_hidden_layers - 1}", + ) + ) + if self.ipu_config.embedding_serialization_factor > 1: + self.resize_vocab(False) + transformations.append(VocabEmbeddingToSerializedEmbedding()) + + return transformations + def parallelize(self): """ - Transform the GPT2 model body to run in an IPU pipeline. + Transform the GPT-2 model body to run in an IPU pipeline. - Adds pipeline stages to the model - (If enabled) Replaces the word embedding with a SerializedEmbedding - Adds recomputation checkpoints """ - super().parallelize() - - # Use optimized attention - for layer in self.transformer.h: - layer.attn.__class__ = OptimizedGPT2Attention - + PipelineMixin.parallelize(self) if self.ipu_config.embedding_serialization_factor > 1: - # Resize token embedding using padding if vocab_size is not a multiple of embedding_serialization_factor. - self.actual_vocab_size = self.config.vocab_size - new_vocab_size = ( - math.ceil(self.config.vocab_size / self.ipu_config.embedding_serialization_factor) - * self.ipu_config.embedding_serialization_factor - ) - if new_vocab_size > self.actual_vocab_size: - self.resize_token_embeddings(new_vocab_size) - - self.transformer.wte = SerializedEmbedding( - self.transformer.wte, self.ipu_config.embedding_serialization_factor - ) - - layer_ipu = get_layer_ipu(self.ipu_config.layers_per_ipu) - - logger.info("-------------------- Device Allocation --------------------") - logger.info("Embedding --> IPU 0") - self.transformer.wte = poptorch.BeginBlock(self.transformer.wte, "Token embedding", ipu_id=0) - self.transformer.wpe = poptorch.BeginBlock(self.transformer.wpe, "Position embedding", ipu_id=0) - hs = outline_attribute(self.transformer.ln_f, "LayerNorm") - self._hooks.extend(hs) - - for index, layer in enumerate(self.transformer.h): - ipu = layer_ipu[index] - if self.ipu_config.recompute_checkpoint_every_layer and index != self.config.num_hidden_layers - 1: - h = recomputation_checkpoint(layer) - self._hooks.append(h) - self.transformer.h[index] = poptorch.BeginBlock(layer, f"Layer{index}", ipu_id=ipu) - logger.info(f"Layer {index:<2} --> IPU {ipu}") - return self + self.resize_vocab(False) + traced = symbolic_trace_pipelined_model(self) + transformations = self.get_transformations() + transformations += _OPTIMIZATION_TRANSFORMATIONS + composition = compose(*transformations) + non_reversible_composition = compose(*_NON_REVERSIBLE_TRANSFORMATIONS) + traced = composition(traced) + traced = non_reversible_composition(traced) + return traced def deparallelize(self): """ Undo the changes to the model done by `parallelize`. You should call this before doing `save_pretrained` so that the `model.state_dict` is - fully compatible with `transformers` models. + fully compatible with the original model. 
""" super().deparallelize() - + transformations = self.get_transformations() + transformations += _OPTIMIZATION_TRANSFORMATIONS + transformations = [t for t in transformations if isinstance(t, ReversibleTransformation)] + composition = compose(*transformations) + self = composition(self, reverse=True) if self.ipu_config.embedding_serialization_factor > 1: - # Deserialize the serialized word embedding - self.transformer.wte = self.transformer.wte.deserialize() - - # Resize token embeddings back to origianl vocab_size - if self.config.vocab_size > self.actual_vocab_size: - self.resize_token_embeddings(self.actual_vocab_size) - - # Switch back to non-optimized attention - for layer in self.transformer.h: - layer.attn.__class__ = GPT2Attention + self.resize_vocab(True) return self @register(GPT2LMHeadModel) -class PipelinedGPT2LMHeadModel(GPT2LMHeadModel, PipelineMixin): +class PipelinedGPT2LMHeadModel(GPT2LMHeadModel, GPT2PipelineMixin): + def get_transformations(self): + log_insertions = self.ipu_config.log_insertions + layer_ipu = get_layer_ipu(self.ipu_config.layers_per_ipu) + transformations = [ + AddPoptorchBlock("Token Embedding", 0, "transformer.wte", log_insertions=log_insertions), + AddPoptorchBlock("Position Embedding", 1, "transformer.wtp", log_insertions=log_insertions), + OutlineAttribute("transformer.ln_f", "LayerNorm"), + AddPoptorchBlocksInSeries("Layer", layer_ipu, r"transformer.h.[0-9]+", log_insertions=log_insertions), + AddPoptorchBlock("LM Head", 0, "lm_head", log_insertions=log_insertions), + ] + if self.ipu_config.recompute_checkpoint_every_layer: + transformations.append( + RecomputationCheckpoint( + "transformer.h.[0-9]+", + to_exclude=f"transformer.h.{self.config.num_hidden_layers - 1}", + ) + ) + if self.ipu_config.embedding_serialization_factor > 1: + self.resize_vocab(False) + transformations += [ + LinearToSerializedLinear("lm_head"), + TieWeights("transformer.wte", "lm_head"), + ] + + return transformations + def parallelize(self): """ - Transform the model to run in an IPU pipeline. + Transform the Roberta model body to run in an IPU pipeline. - Adds pipeline stages to the model + - (If enabled) Replaces the word embedding with a SerializedEmbedding - Adds recomputation checkpoints - - Recommended usage: - ``` - model = PipelinedGPT2LMHeadModel(config).parallelize().half() - ``` """ PipelineMixin.parallelize(self) - - # Use optimized attention - for layer in self.transformer.h: - layer.attn.__class__ = OptimizedGPT2Attention - if self.ipu_config.embedding_serialization_factor > 1: - # Resize token embedding using padding if vocab_size is not a multiple of embedding_serialization_factor. - self.actual_vocab_size = self.config.vocab_size - new_vocab_size = ( - math.ceil(self.config.vocab_size / self.ipu_config.embedding_serialization_factor) - * self.ipu_config.embedding_serialization_factor - ) - if new_vocab_size > self.actual_vocab_size: - # There is a tie_weights operation in resize_token_embeddings so the lm_head's weight is also resized. 
- self.resize_token_embeddings(new_vocab_size) - - serialized_lm_head = SerializedLinear( - self.config.n_embd, - self.config.vocab_size, # Note that if padding is done, self.config.vocab_size == new_vocab_size - self.ipu_config.embedding_serialization_factor, - bias=False, - mode=poptorch.MatMulSerializationMode.OutputChannels, - ) - serialized_lm_head.load_state_dict(self.lm_head.state_dict()) - self.lm_head = serialized_lm_head - self.tie_weights() - - layer_ipu = get_layer_ipu(self.ipu_config.layers_per_ipu) - - logger.info("-------------------- Device Allocation --------------------") - logger.info("Token Embedding --> IPU 0") - self.transformer.wte = poptorch.BeginBlock(self.transformer.wte, "Token embedding", ipu_id=0) - logger.info("Position Embedding --> IPU 1") - self.transformer.wpe = poptorch.BeginBlock(self.transformer.wpe, "Position embedding", ipu_id=1) - hs = outline_attribute(self.transformer.ln_f, "LayerNorm") - self._hooks.extend(hs) - - for index, layer in enumerate(self.transformer.h): - ipu = layer_ipu[index] - if self.ipu_config.recompute_checkpoint_every_layer: - h = recomputation_checkpoint(layer) - self._hooks.append(h) - self.transformer.h[index] = poptorch.BeginBlock(layer, f"Layer{index}", ipu_id=ipu) - logger.info(f"Layer {index:<2} --> IPU {ipu}") - - logger.info("Head --> IPU 0") - self.lm_head = poptorch.BeginBlock(self.lm_head, "LM head", ipu_id=0) - logger.info("-----------------------------------------------------------") - return self + self.resize_vocab(False) + traced = symbolic_trace_pipelined_model(self) + transformations = self.get_transformations() + transformations += _OPTIMIZATION_TRANSFORMATIONS + composition = compose(*transformations) + non_reversible_composition = compose(*_NON_REVERSIBLE_TRANSFORMATIONS) + traced = composition(traced) + traced = non_reversible_composition(traced) + return traced def deparallelize(self): + """ + Undo the changes to the model done by `parallelize`. + You should call this before doing `save_pretrained` so that the `model.state_dict` is + fully compatible with the original model. + """ PipelineMixin.deparallelize(self) - + transformations = self.get_transformations() + transformations += _OPTIMIZATION_TRANSFORMATIONS + transformations = [t for t in transformations if isinstance(t, ReversibleTransformation)] + composition = compose(*transformations) + self = composition(self, reverse=True) if self.ipu_config.embedding_serialization_factor > 1: - # Deserialize the serialized linear layer - old_lm_head = nn.Linear( - self.config.n_embd, - self.config.vocab_size, # Note that if padding is done, self.config.vocab_size == new_vocab_size - bias=False, - ) - old_lm_head.load_state_dict(self.lm_head.state_dict()) - self.lm_head = old_lm_head - self.tie_weights() - - # Resize token embeddings back to origianl vocab_size. - # There is a tie_weights operation in resize_token_embeddings so the lm_head's weight is also resized. 
- if self.config.vocab_size > self.actual_vocab_size: - self.resize_token_embeddings(self.actual_vocab_size) - - # Switch back to non-optimized attention - for layer in self.transformer.h: - layer.attn.__class__ = GPT2Attention + self.resize_vocab(True) return self def forward( @@ -280,14 +288,6 @@ def forward( @register(GPT2ForSequenceClassification) class PipelinedGPT2ForSequenceClassification(GPT2ForSequenceClassification, GPT2PipelineMixin): - def parallelize(self): - super().parallelize() - last_ipu = self.ipu_config.ipus_per_replica - 1 - logger.info(f"Head --> IPU {last_ipu}") - self.score = poptorch.BeginBlock(self.score, "Score", ipu_id=last_ipu) - logger.info("-----------------------------------------------------------") - return self - def forward( self, input_ids: Optional[torch.LongTensor] = None, @@ -337,10 +337,4 @@ def forward( @register(GPT2ForTokenClassification) class PipelinedGPT2ForTokenClassification(GPT2ForTokenClassification, GPT2PipelineMixin): - def parallelize(self): - super().parallelize() - last_ipu = self.ipu_config.ipus_per_replica - 1 - logger.info(f"Head --> IPU {last_ipu}") - self.classifier = poptorch.BeginBlock(self.classifier, "Classifier", ipu_id=last_ipu) - logger.info("-----------------------------------------------------------") - return self + pass diff --git a/optimum/graphcore/models/lxmert/modeling_lxmert.py b/optimum/graphcore/models/lxmert/modeling_lxmert.py index a186fcb41..65c1cb4a0 100644 --- a/optimum/graphcore/models/lxmert/modeling_lxmert.py +++ b/optimum/graphcore/models/lxmert/modeling_lxmert.py @@ -11,73 +11,99 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -from typing import Optional, Tuple, Union - import torch import torch.nn.functional as F import poptorch -from optimum.utils import logging -from transformers import LxmertForQuestionAnswering -from transformers.models.lxmert.modeling_lxmert import LxmertForQuestionAnsweringOutput - -from ...modeling_utils import PipelineMixin, recomputation_checkpoint, register +import transformers + +from ....fx.optimization import ChangeTrueDivToMulByInverse, MergeLinears, compose +from ....utils import logging +from ...fx.transformations import ( + AddPoptorchBlock, + AddPoptorchBlocksInSeries, + ClipValues, + ClipValuesSymmetric, + RecomputationCheckpoint, + TupleOutput, +) +from ...fx.utils import symbolic_trace_pipelined_model +from ...modeling_utils import PipelineMixin, get_layer_ipu, register logger = logging.get_logger(__name__) +_OPTIMIZATION_TRANSFORMATIONS = [ + ChangeTrueDivToMulByInverse(), + MergeLinears(), + # FuseBiasInLinear(), +] + +_NON_REVERSIBLE_TRANSFORMATIONS = [ + ClipValuesSymmetric(1e4, exclude_targets=["view"]), + ClipValues(1e-4, float("inf"), include_targets=[torch.nn.LayerNorm]), + TupleOutput(), +] + @register(LxmertForQuestionAnswering) class PipelinedLxmertForQuestionAnswering(LxmertForQuestionAnswering, PipelineMixin): + def get_transformations(self): + log_insertions = self.ipu_config.log_insertions + layer_ipu = get_layer_ipu(self.ipu_config.layers_per_ipu) + language_layers_ipus = layer_ipu[: self.config.l_layers] + visual_layers_ipus = layer_ipu[self.config.l_layers : self.config.l_layers + self.r_layers] + cross_modality_layers_ipus = layer_ipu[self.config.l_layers + self.r_layers :] + + transformations = [ + AddPoptorchBlock("Embedding", 0, "lxmert.embeddings", log_insertions=log_insertions), + AddPoptorchBlock("Image 
Embedding", 0, "lxmert.encoder.visn_fc", log_insertions=log_insertions), + AddPoptorchBlocksInSeries( + "Language Layer", language_layers_ipus, r"lxmert.encoder.layer.[0-9]+", log_insertions=log_insertions + ), + AddPoptorchBlocksInSeries( + "Visual Layer", visual_layers_ipus, r"lxmert.encoder.r_layers.[0-9]+", log_insertions=log_insertions + ), + AddPoptorchBlocksInSeries( + "Cross Modality Layer", + cross_modality_layers_ipus, + r"lxmert.encoder.x_layers.[0-9]+", + log_insertions=log_insertions, + ), + AddPoptorchBlock("Pooler Output", layer_ipu[-1], "lxmert.pooler", log_insertions=log_insertions), + AddPoptorchBlock("Head Output", layer_ipu[-1], "lxmert.answer_head", log_insertions=log_insertions), + ] + if self.ipu_config.recompute_checkpoint_every_layer: + transformations += [ + RecomputationCheckpoint( + "lxmert.encoder.layer.[0-9]+", to_exclude=f"lxmert.encoder.layer.{self.config.l_layers - 1}" + ), + RecomputationCheckpoint( + "lxmert.encoder.r_layers.[0-9]+", to_exclude=f"lxmert.encoder.r_layers.{self.config.r_layers - 1}" + ), + RecomputationCheckpoint( + "lxmert.encoder.x_layers.[0-9]+", to_exclude=f"lxmert.encoder.x_layers.{self.config.x_layers - 1}" + ), + ] + return transformations + def parallelize(self): - """ - Transform the model to run in an IPU pipeline. - - Adds pipeline stages to the model - - Adds recomputation checkpoints - Recommended usage: - ``` - model = PipelinedLxmertForQuestionAnswering(config).parallelize().half() - ``` - """ - self._hooks = [] - - logger.info("-------------------- Device Allocation --------------------") - logger.info("Embedding --> IPU 0") - self.lxmert.embeddings = poptorch.BeginBlock(self.lxmert.embeddings, "Embedding", ipu_id=0) - logger.info("Image embedding --> IPU 0") - self.lxmert.encoder.visn_fc = poptorch.BeginBlock(self.lxmert.encoder.visn_fc, "Image embedding", ipu_id=0) - - # Language layers - for index, layer in enumerate(self.lxmert.encoder.layer): - if self.ipu_config.recompute_checkpoint_every_layer: - h = recomputation_checkpoint(layer) - self._hooks.append(h) - self.lxmert.encoder.layer[index] = poptorch.BeginBlock(layer, f"Language layer{index}", ipu_id=1) - logger.info(f"Language layer {index:<2} --> IPU 1") - - # Visual layers - for index, layer in enumerate(self.lxmert.encoder.r_layers): - if self.ipu_config.recompute_checkpoint_every_layer: - h = recomputation_checkpoint(layer) - self._hooks.append(h) - self.lxmert.encoder.r_layers[index] = poptorch.BeginBlock(layer, f"Visual layer{index}", ipu_id=2) - logger.info(f"Visual layer {index:<2} --> IPU 2") - - # Cross modality layers - for index, layer in enumerate(self.lxmert.encoder.x_layers): - if self.ipu_config.recompute_checkpoint_every_layer: - h = recomputation_checkpoint(layer) - self._hooks.append(h) - self.lxmert.encoder.x_layers[index] = poptorch.BeginBlock(layer, f"Cross modality layer{index}", ipu_id=3) - logger.info(f"Cross modality layer {index:<2} --> IPU 3") - - logger.info(f"Pooler --> IPU 3") - self.lxmert.pooler = poptorch.BeginBlock(self.lxmert.pooler, "Pooler", ipu_id=3) - - logger.info(f"Head --> IPU 3") - self.answer_head = poptorch.BeginBlock(self.answer_head, "Head", ipu_id=3) - logger.info("-----------------------------------------------------------") + super().parallelize() + traced = symbolic_trace_pipelined_model(self) + transformations = self.get_transformations() + transformations += _OPTIMIZATION_TRANSFORMATIONS + composition = compose(*transformations) + non_reversible_composition = compose(*_NON_REVERSIBLE_TRANSFORMATIONS) + traced = 
composition(traced) + traced = non_reversible_composition(traced) + return traced + + def deparallelize(self): + super().deparallelize() + transformations = self.get_transformations() + transformations += _OPTIMIZATION_TRANSFORMATIONS + composition = compose(*transformations) + self = composition(self, reverse=True) return self def forward( diff --git a/optimum/graphcore/models/roberta/modeling_roberta.py b/optimum/graphcore/models/roberta/modeling_roberta.py index 24d88f765..53235c9ca 100644 --- a/optimum/graphcore/models/roberta/modeling_roberta.py +++ b/optimum/graphcore/models/roberta/modeling_roberta.py @@ -19,7 +19,6 @@ from torch.nn import CrossEntropyLoss import poptorch -from optimum.utils import logging from transformers import ( RobertaForMaskedLM, RobertaForMultipleChoice, @@ -30,6 +29,7 @@ from transformers.modeling_outputs import MaskedLMOutput, QuestionAnsweringModelOutput from ....fx.optimization import ChangeTrueDivToMulByInverse, MergeLinears, compose +from ....utils import logging from ...fx.transformations import ( AddPoptorchBlock, AddPoptorchBlocksInSeries, @@ -43,12 +43,7 @@ VocabEmbeddingToSerializedEmbedding, ) from ...fx.utils import symbolic_trace_pipelined_model -from ...modeling_utils import ( - OnehotGather, - PipelineMixin, - get_layer_ipu, - register, -) +from ...modeling_utils import OnehotGather, PipelineMixin, get_layer_ipu, register logger = logging.get_logger(__name__) diff --git a/optimum/graphcore/models/wav2vec2/modeling_wav2vec2.py b/optimum/graphcore/models/wav2vec2/modeling_wav2vec2.py index 24f529395..5111f2769 100755 --- a/optimum/graphcore/models/wav2vec2/modeling_wav2vec2.py +++ b/optimum/graphcore/models/wav2vec2/modeling_wav2vec2.py @@ -19,7 +19,6 @@ import torch.nn.functional as F import poptorch -from optimum.utils import logging from transformers import Wav2Vec2ForPreTraining, Wav2Vec2Model from transformers.modeling_outputs import CausalLMOutput from transformers.models.wav2vec2.modeling_wav2vec2 import ( @@ -27,17 +26,40 @@ Wav2Vec2Encoder, Wav2Vec2EncoderStableLayerNorm, Wav2Vec2ForCTC, - Wav2Vec2ForPreTrainingOutput, Wav2Vec2GumbelVectorQuantizer, ) -from ...modeling_utils import PipelineMixin, get_layer_ipu, recomputation_checkpoint, register +from ....fx.optimization import ChangeTrueDivToMulByInverse, MergeLinears, compose +from ....utils import logging +from ...fx.transformations import ( + AddPoptorchBlock, + AddPoptorchBlocksInSeries, + ClipValues, + ClipValuesSymmetric, + RecomputationCheckpoint, + TupleOutput, +) +from ...fx.utils import symbolic_trace_pipelined_model +from ...modeling_utils import PipelineMixin, get_layer_ipu, register + from .ipu_gumbel_vector_quantizer import IPUWav2Vec2GumbelVectorQuantizer from .ipu_layer_drop import IPUWav2Vec2Adapter, IPUWav2Vec2Encoder, IPUWav2Vec2EncoderStableLayerNorm logger = logging.get_logger(__name__) +_OPTIMIZATION_TRANSFORMATIONS = [ + ChangeTrueDivToMulByInverse(), + MergeLinears(), + # FuseBiasInLinear(), +] + +_NON_REVERSIBLE_TRANSFORMATIONS = [ + ClipValuesSymmetric(1e4, exclude_targets=["view"]), + ClipValues(1e-4, float("inf"), include_targets=[torch.nn.LayerNorm]), + TupleOutput(), +] + class IPUWav2Vec2Model(Wav2Vec2Model): def _get_feature_vector_attention_mask( @@ -73,8 +95,51 @@ def _get_feature_vector_attention_mask( return attention_mask +class Wav2Vec2PipelineMixin(PipelineMixin): + def get_transformations(self): + log_insertions = self.ipu_config.log_insertions + layer_ipu = get_layer_ipu(self.ipu_config.layers_per_ipu) + feature_extractor_conv_layers_ipu = 
layer_ipu[:self.config.num_feat_extract_layers] + transformations = [ + AddPoptorchBlocksInSeries( + "Conv", feature_extractor_conv_layers_ipu, r"wav2vec2.feature_extractor.conv_layers.[0-9]+", log_insertions=log_insertions + ), + AddPoptorchBlock("Positional Embedding", layer_ipu[self.config.num_feat_extract_layers], "wav2vec2.encoder.pos_conv_embed", log_insertions=log_insertions), + AddPoptorchBlocksInSeries( + "Encoder", layer_ipu[self.config.num_feat_extract_layers + 1:], r"wav2vec2.encoder.layers.[0-9]+", log_insertions=log_insertions + ), + ] + if self.ipu_config.recompute_checkpoint_every_layer: + transformations += [ + RecomputationCheckpoint( + "wav2vec2.encoder.layers.[0-9]+", + to_exclude=f"wav2vec2.encoder.layers.{self.config.num_hidden_layers - 1}", + ), + ] + return transformations + + def parallelize(self): + super().parallelize() + traced = symbolic_trace_pipelined_model(self) + transformations = self.get_transformations() + transformations += _OPTIMIZATION_TRANSFORMATIONS + composition = compose(*transformations) + non_reversible_composition = compose(*_NON_REVERSIBLE_TRANSFORMATIONS) + traced = composition(traced) + traced = non_reversible_composition(traced) + return traced + + def deparallelize(self): + super().deparallelize() + transformations = self.get_transformations() + transformations += _OPTIMIZATION_TRANSFORMATIONS + composition = compose(*transformations) + self = composition(self, reverse=True) + return self + + @register(Wav2Vec2ForPreTraining) -class PipelinedWav2Vec2ForPreTraining(Wav2Vec2ForPreTraining, PipelineMixin): +class PipelinedWav2Vec2ForPreTraining(Wav2Vec2ForPreTraining, Wav2Vec2PipelineMixin): def change_wav2vec2_encoder_class(self, restore: bool): """Changes the encoder class to update its forward pass so that it uses our custom version. 
@@ -125,8 +190,17 @@ def change_conv_eps(self, restore: bool): self.original_eps.append(conv_layer.layer_norm.eps) conv_layer.layer_norm.eps = eps - def _add_begin_block(self, module, name, ipu_id): - poptorch.BeginBlock(module, name, ipu_id) + def get_transformations(self): + log_insertions = self.ipu_config.log_insertions + layer_ipu = get_layer_ipu(self.ipu_config.layers_per_ipu) + transformations = super().get_transformations() + start_idx = self.config.num_feat_extract_layers + 2 + transformations += [ + AddPoptorchBlock("Project Hidden", layer_ipu[start_idx], "project_hid", log_insertions=log_insertions), + AddPoptorchBlock("Quantizer", layer_ipu[start_idx + 1], "quantizer", log_insertions=log_insertions), + AddPoptorchBlock("Project Quantizer", layer_ipu[start_idx + 2], "project_q", log_insertions=log_insertions), + ] + return transformations def parallelize(self): """ @@ -140,41 +214,12 @@ def parallelize(self): model = PipelinedWav2Vec2ForPreTraining(config).parallelize().half() ``` """ - super().parallelize() - self.wav2vec2.__class__ = IPUWav2Vec2Model self.change_wav2vec2_encoder_class(False) self.change_wav2vec2_adapter_class(False) self.change_quantizer_class(False) self.change_conv_eps(False) - - logger.info("---------- Device Allocation -----------") - layer_ipu = get_layer_ipu(self.ipu_config.layers_per_ipu) - layers = [] - # Conv layers - for index, layer in enumerate(self.wav2vec2.feature_extractor.conv_layers): - layers.append((f"Conv {index:<2}", layer)) - # Positional Embedding - layers.append(("Positional Embedding", self.wav2vec2.encoder.pos_conv_embed)) - # Encoder layers - for index, layer in enumerate(self.wav2vec2.encoder.layers): - recomputation_checkpoint(layer) - layers.append((f"Encoder {index:<2}", layer)) - # Project Hidden - layers.append(("Project Hidden", self.project_hid)) - # Quantizer - layers.append(("Quantizer", self.quantizer)) - # Project Quantizer - layers.append(("Project Quantizer", self.project_q)) - - if len(layer_ipu) != len(layers): - raise ValueError(f"Layers per IPU total ({len(layer_ipu)}) must be equal to layers ({len(layers)}).") - - for i, (name, layer) in enumerate(layers): - logger.info(f"{name} --> IPU {layer_ipu[i]}") - self._add_begin_block(layer, name, ipu_id=layer_ipu[i]) - - logger.info("---------------------------------------") + return super().parallelize() def deparallelize(self): """ @@ -182,13 +227,12 @@ def deparallelize(self): You should call this before doing `save_pretrained` so that the `model.state_dict` is fully compatible with `transformers.Wav2Vec2ForPreTraining`. """ - super().deparallelize() self.change_wav2vec2_encoder_class(True) self.change_wav2vec2_adapter_class(True) self.change_quantizer_class(True) self.change_conv_eps(True) - self.wav2vec2.__class__ = Wav2Vec2Model - return self + self.wav2vec2.__class__ = IPUWav2Vec2Model + return super().deparallelize() def forward( self, @@ -368,7 +412,7 @@ def compute_contrastive_logits( @register(Wav2Vec2ForCTC) -class PipelinedWav2Vec2ForCTC(Wav2Vec2ForCTC, PipelineMixin): +class PipelinedWav2Vec2ForCTC(Wav2Vec2ForCTC, Wav2Vec2PipelineMixin): def change_wav2vec2_encoder_class(self, restore: bool): """Changes the encoder class to update its forward pass so that it uses our custom version. 
@@ -414,6 +458,17 @@ def change_conv_eps(self, restore: bool): def _add_begin_block(self, module, name, ipu_id): poptorch.BeginBlock(module, name, ipu_id) + def get_transformations(self): + log_insertions = self.ipu_config.log_insertions + layer_ipu = get_layer_ipu(self.ipu_config.layers_per_ipu) + print(layer_ipu) + transformations = super().get_transformations() + start_idx = self.config.num_feat_extract_layers + 2 + transformations.append( + AddPoptorchBlock("Project Hidden", layer_ipu[start_idx], "lm_head", log_insertions=log_insertions) + ) + return transformations + def parallelize(self): """ Transform the model to run in an IPU pipeline. @@ -426,38 +481,13 @@ def parallelize(self): model = PipelinedWav2Vec2ForPreTraining(config).parallelize().half() ``` """ - super().parallelize() - self.wav2vec2.__class__ = IPUWav2Vec2Model - self.freeze_feature_encoder() - self.change_wav2vec2_encoder_class(False) - self.change_wav2vec2_adapter_class(False) - self.change_conv_eps(False) - - if self.ipu_config.ipus_per_replica != 1: - logger.info("---------- Device Allocation -----------") - layer_ipu = get_layer_ipu(self.ipu_config.layers_per_ipu) - layers = [] - # Conv layers - for index, layer in enumerate(self.wav2vec2.feature_extractor.conv_layers): - layers.append((f"Conv {index:<2}", layer)) - # Positional Embedding - layers.append(("Positional Embedding", self.wav2vec2.encoder.pos_conv_embed)) - # Encoder layers - for index, layer in enumerate(self.wav2vec2.encoder.layers): - recomputation_checkpoint(layer) - layers.append((f"Encoder {index:<2}", layer)) - # Project Hidden - layers.append(("Project Hidden", self.lm_head)) - - if len(layer_ipu) != len(layers): - raise ValueError(f"Layers per IPU total ({len(layer_ipu)}) must be equal to layers ({len(layers)}).") - - for i, (name, layer) in enumerate(layers): - logger.info(f"{name} --> IPU {layer_ipu[i]}") - self._add_begin_block(layer, name, ipu_id=layer_ipu[i]) - - logger.info("---------------------------------------") + if not isinstance(self, torch.fx.GraphModule): + self.freeze_feature_encoder() + self.change_wav2vec2_encoder_class(False) + self.change_wav2vec2_adapter_class(False) + self.change_conv_eps(False) + return super().parallelize() def deparallelize(self): """ @@ -465,12 +495,11 @@ def deparallelize(self): You should call this before doing `save_pretrained` so that the `model.state_dict` is fully compatible with `transformers.Wav2Vec2ForPreTraining`. 
""" - super().deparallelize() self.change_wav2vec2_encoder_class(True) self.change_wav2vec2_adapter_class(True) self.change_conv_eps(True) self.wav2vec2.__class__ = Wav2Vec2Model - return self + return super().deparallelize() def forward( self, diff --git a/optimum/graphcore/trainer.py b/optimum/graphcore/trainer.py index 2ba3bb47e..45ec3d6de 100644 --- a/optimum/graphcore/trainer.py +++ b/optimum/graphcore/trainer.py @@ -474,11 +474,15 @@ def compile_model( else: sample_batch = self._prepare_inputs(sample_batch) model = self.model if training else self.model_for_eval + if training: + model.train() + else: + model.eval() + signature = inspect.signature(model.forward) if isinstance(sample_batch, tuple): - signature = inspect.signature(model.forward) model.input_names_for_symbolic_trace = list(signature.parameters.keys())[: len(sample_batch)] else: - model.input_names_for_symbolic_trace = list(sample_batch.keys()) + model.input_names_for_symbolic_trace = [p for p in signature.parameters if p in sample_batch] if not isinstance(model, torch.fx.GraphModule): model = model.parallelize() if not self.args.fp32: @@ -932,7 +936,7 @@ def wrap_model(self, model: Union[PreTrainedModel, PoplarExecutor], training: bo wrapped = self.training_model else: if self.inference_model is None: - self.inference_model = poptorch.inferenceModel(model.eval(), options=self.eval_opts) + self.inference_model = poptorch.inferenceModel(model, options=self.eval_opts) wrapped = self.inference_model # Attaching to device when the model that is being access was already compiled but detached from previous loop. From 22d65c7b5f2f35585fc6f6956c287a03c20bcaec Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Tue, 30 Aug 2022 15:26:21 +0200 Subject: [PATCH 13/33] Removed unused code in T5 --- optimum/graphcore/models/t5/modeling_t5.py | 161 +++++---------------- 1 file changed, 34 insertions(+), 127 deletions(-) diff --git a/optimum/graphcore/models/t5/modeling_t5.py b/optimum/graphcore/models/t5/modeling_t5.py index 8f63048f9..59b24ef61 100644 --- a/optimum/graphcore/models/t5/modeling_t5.py +++ b/optimum/graphcore/models/t5/modeling_t5.py @@ -60,10 +60,6 @@ class PipelinedT5ForConditionalGeneration( GenerationMethodsMixin, T5ForConditionalGeneration, PipelineMixin, IPUGenerationMixin ): - @property - def is_encoder_and_decoder_embeddings_computation_shared(self): - return isinstance(self.shared, SharedEmbedding) - def scale_down_weights(self, factor: float = 1, restore: bool = False): self.lm_scale_modifier = 1 if not restore else None # self.lm_scale_modifier = nn.Parameter(torch.ones(self.config.d_model, dtype=torch.float16)) if not restore else None @@ -178,126 +174,37 @@ def deparallelize(self): self = composition(self, reverse=True) return self - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.FloatTensor] = None, - decoder_input_ids: Optional[torch.LongTensor] = None, - decoder_attention_mask: Optional[torch.BoolTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - decoder_head_mask: Optional[torch.FloatTensor] = None, - cross_attn_head_mask: Optional[torch.Tensor] = None, - encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - decoder_inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - 
output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]: - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if self.is_encoder_and_decoder_embeddings_computation_shared: - inputs_embeds, decoder_inputs_embeds = self.shared( - input_ids=input_ids, - decoder_input_ids=decoder_input_ids, - ) - if inputs_embeds is not None: - input_ids = None - if decoder_inputs_embeds is not None: - decoder_input_ids = None - - # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask - if head_mask is not None and decoder_head_mask is None: - if self.config.num_layers == self.config.num_decoder_layers: - warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning) - decoder_head_mask = head_mask - - # Encode if needed (training, first prediction pass) - if encoder_outputs is None: - # Convert encoder inputs in embeddings if needed - encoder_outputs = self.encoder( - input_ids=input_ids, - attention_mask=attention_mask, - inputs_embeds=inputs_embeds, - head_mask=head_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): - encoder_outputs = BaseModelOutput( - last_hidden_state=encoder_outputs[0], - hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, - attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, - ) - - hidden_states = encoder_outputs[0] - - if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None: - # get decoder inputs from shifting lm labels to the right - decoder_input_ids = self._shift_right(labels) - - # Decode - decoder_outputs = self.decoder( - input_ids=decoder_input_ids, - attention_mask=decoder_attention_mask, - inputs_embeds=decoder_inputs_embeds, - past_key_values=past_key_values, - encoder_hidden_states=hidden_states, - encoder_attention_mask=attention_mask, - head_mask=decoder_head_mask, - cross_attn_head_mask=cross_attn_head_mask, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - sequence_output = decoder_outputs[0] - - # Set device for model parallelism - if self.model_parallel: - self.lm_head = self.lm_head.to(self.encoder.first_device) - sequence_output = sequence_output.to(self.lm_head.weight.device) - - if self.config.tie_word_embeddings: - # Rescale output before projecting on vocab - # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 - sequence_output = sequence_output * (self.model_dim**-0.5) - - lm_scale_modifier = getattr(self, "lm_scale_modifier", None) - if lm_scale_modifier is not None: - sequence_output = sequence_output * lm_scale_modifier - - lm_logits = self.lm_head(sequence_output) - - loss = None - if labels is not None: - loss_fct = nn.CrossEntropyLoss(ignore_index=-100) - loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1)) - # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666 - - # Only returning the loss to make the communication between the host and the device faster. 
- if not return_dict: - output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs - return (loss,) if labels is not None else output - - if loss is not None: - return Seq2SeqLMOutput( - loss=loss, - ) - return Seq2SeqLMOutput( - loss=loss, - logits=lm_logits, - past_key_values=decoder_outputs.past_key_values, - decoder_hidden_states=decoder_outputs.hidden_states, - decoder_attentions=decoder_outputs.attentions, - cross_attentions=decoder_outputs.cross_attentions, - encoder_last_hidden_state=encoder_outputs.last_hidden_state, - encoder_hidden_states=encoder_outputs.hidden_states, - encoder_attentions=encoder_outputs.attentions, - ) + # def train(self, mode: bool = True) -> "PipelinedT5ForConditionalGeneration": + # mod = super(T5ForConditionalGeneration, self).train(mode=mode) + # # TODO: enable that once generation is supported. + # # mod.forward = mod._forward_for_train if mode else mod._forward_for_generate + # mod.forward = mod._forward_for_train + # return mod + + # def _forward_for_train(self, input_ids, attention_mask, decoder_input_ids, labels=None): + # outputs = super().forward( + # input_ids=input_ids, + # attention_mask=attention_mask, + # decoder_input_ids=decoder_input_ids, + # labels=labels, + # use_cache=False, + # return_dict=False, + # ) + # # Only returning the loss to make the communication between the host and the device faster. + # return outputs[0:1] + + # def _forward_for_generate(self, encoder_outputs, decoder_input_ids, attention_mask, labels=None): + # outputs = super().forward( + # encoder_outputs=encoder_outputs, + # attention_mask=attention_mask, + # decoder_input_ids=decoder_input_ids, + # return_dict=False, + # use_cache=False, + # labels=labels, + # ) + # # Only returning the loss (if labels is provided) and the logits. + # if labels is None: + # return outputs[:1] + # return outputs[:2] + + # forward = _forward_for_train From 41b7814df629f6e9743fc76da9d34153a5dd1c3c Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Wed, 31 Aug 2022 16:46:16 +0200 Subject: [PATCH 14/33] Deberta recomputation checkpoint --- optimum/graphcore/fx/transformations.py | 64 +++++++++++++++++-- optimum/graphcore/fx/utils.py | 4 ++ .../models/deberta/modeling_deberta.py | 35 +++++----- 3 files changed, 80 insertions(+), 23 deletions(-) diff --git a/optimum/graphcore/fx/transformations.py b/optimum/graphcore/fx/transformations.py index de8710e68..837e91cf3 100644 --- a/optimum/graphcore/fx/transformations.py +++ b/optimum/graphcore/fx/transformations.py @@ -15,7 +15,7 @@ import collections import operator import re -from typing import TYPE_CHECKING, Callable, List, Optional, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union import torch @@ -43,6 +43,11 @@ def parent_module_qualified_name(node: "Node") -> str: return name if name != "root" else "" +def parent_module_type(node: "Node") -> Union[str, Type]: + return getattr(node, "parent_module_type", None) + + + class AddPoptorchBlockBase(ReversibleTransformation): """ Base class that provide useful methods for inserting poptorch blocks in the model. @@ -249,9 +254,14 @@ class RecomputationCheckpoint(ReversibleTransformation): Annotates the output of a module to be checkpointed instead of recomputed. 
""" - def __init__(self, name_regex: str, to_exclude: Optional[str] = None): + def __init__(self, name_regex: str, to_exclude: Optional[str] = None, output_nodes_specs: Dict[str, List[Any]] = None): self.name_regex = re.compile(name_regex) self.to_exclude = re.compile(to_exclude) if to_exclude is not None else None + self.output_nodes_specs = None + if output_nodes_specs is not None: + self.output_nodes_specs = collections.defaultdict(set) + for k, v in output_nodes_specs.items(): + self.output_nodes_specs[k] = v def find_output_nodes_for_module_name(self, graph_module: "GraphModule", module_qualified_name: str): nodes_in_module = set() @@ -277,8 +287,11 @@ def find_output_nodes_for_module_name(self, graph_module: "GraphModule", module_ else: break modules_from_the_past.add(name) - # nodes_in_module = {n for n in nodes_in_module if set(n.users.keys()) & nodes_in_module} - return [n for n in nodes_in_module if set(n.users.keys()) - nodes_in_module] + output_nodes = [n for n in nodes_in_module if set(n.users.keys()) - nodes_in_module] + if self.output_nodes_specs: + output_nodes = [n for n in output_nodes if n.target in self.output_nodes_specs[n.op]] + return output_nodes + def transform(self, graph_module: "GraphModule") -> "GraphModule": matched_module_names = collections.OrderedDict() @@ -310,6 +323,49 @@ def reverse(self, graph_module: "GraphModule") -> "GraphModule": return graph_module +class AutoCast(ReversibleTransformation): + def __init__(self, module_types: Union[Type, Set[Type]], name_regex: Optional[str] = None): + self.module_types = module_types + if not isinstance(self.module_types, (list, tuple, set)): + self.module_types = {self.module_types} + if not isinstance(self.module_types, set): + self.module_types = set(self.module_types) + self.name_regex = re.compile(name_regex) if name_regex is not None else None + + def find_start_and_end_nodes(self, graph_module: "GraphModule") -> List[Tuple["Node", "Node"]]: + start_and_end_nodes = [] + start_node = None + end_node = None + for node in graph_module.graph.nodes: + name = parent_module_qualified_name(node) + if self.name_regex is not None and not re.match(self.name_regex, name): + continue + type_ = parent_module_type(node) + if type_ in self.module_types and start_node is None: + start_node = node + elif type_ in self.module_types: + end_node = node + elif start_node is not None: + start_and_end_nodes.append((start_node, end_node)) + start_node = None + return start_and_end_nodes + + def transform(self, graph_module: "GraphModule") -> "GraphModule": + start_and_end_nodes = self.find_start_and_end_nodes(graph_module) + for start_node, end_node in start_and_end_nodes: + with graph_module.graph.inserting_before(start_node): + graph_module.graph.call_function(torch.ops.poptorch.begin_autocast) + with graph_module.graph.inserting_after(end_node): + graph_module.graph.call_function(torch.ops.poptorch.restore_autocast) + return graph_module + + def reverse(self, graph_module: "GraphModule") -> "GraphModule": + for node in graph_module.graph.nodes: + if node.target in [torch.ops.poptorch.begin_autocast, torch.ops.poptorch.restore_autocast]: + graph_module.graph.erase_node(node) + return graph_module + + class VocabEmbeddingToSerializedEmbedding(ReversibleTransformation): """ Transforms the embedding layer matching name_regex to a SerializedEmbedding layer. 
diff --git a/optimum/graphcore/fx/utils.py b/optimum/graphcore/fx/utils.py index c149e8998..1588fae90 100644 --- a/optimum/graphcore/fx/utils.py +++ b/optimum/graphcore/fx/utils.py @@ -36,6 +36,7 @@ def __init__(self, autowrap_modules=(math,), autowrap_functions=()): super().__init__(autowrap_modules=autowrap_modules, autowrap_functions=autowrap_functions) self.ops_to_wrap = [] self.current_module_qualified_name = ["root"] + self.current_module_type = ["root"] def register_op_to_wrap(self, name, wrapper, orig_op): self.ops_to_wrap.append((name, wrapper, orig_op)) @@ -61,6 +62,7 @@ def proxy(self, node): # Would be better to update the created node in TracerBase.create_node, but this method has less arguments, so # it is easier to use this one, and equivalent. node.parent_module_qualified_name = self.current_module_qualified_name[-1] + node.parent_module_type = self.current_module_type[-1] proxy = super().proxy(node) return proxy @@ -71,10 +73,12 @@ def call_module(self, m, forward, args, kwargs): is_leaf_module = self.is_leaf_module(m, module_qualified_name) if not is_leaf_module: self.current_module_qualified_name.append(module_qualified_name) + self.current_module_type.append(type(m)) self.orig_forward = forward proxy = super().call_module(m, forward, args, kwargs) if not is_leaf_module: self.current_module_qualified_name.pop(-1) + self.current_module_type.pop(-1) return proxy diff --git a/optimum/graphcore/models/deberta/modeling_deberta.py b/optimum/graphcore/models/deberta/modeling_deberta.py index 412b0bb00..318abdbd9 100644 --- a/optimum/graphcore/models/deberta/modeling_deberta.py +++ b/optimum/graphcore/models/deberta/modeling_deberta.py @@ -13,11 +13,10 @@ # limitations under the License. import math -from typing import Optional, Tuple, Union +import operator import torch import torch.nn as nn -import torch.nn.functional as F import poptorch from transformers import DebertaForQuestionAnswering, DebertaForSequenceClassification, DebertaForTokenClassification @@ -33,6 +32,7 @@ from ...fx.transformations import ( AddPoptorchBlock, AddPoptorchBlocksInSeries, + AutoCast, ClipValues, ClipValuesSymmetric, OutlineAttribute, @@ -47,7 +47,7 @@ logger = logging.get_logger(__name__) _OPTIMIZATION_TRANSFORMATIONS = [ - # ChangeTrueDivToMulByInverse(), + ChangeTrueDivToMulByInverse(), MergeLinears(), # FuseBiasInLinear(), ] @@ -65,9 +65,6 @@ class FastGatherLastDim(nn.Module): on the last dimension of a tensor. 
""" - def __init__(self): - super().__init__() - def forward(self, data, idx, target=None): if poptorch.isRunningOnIpu(): if target is None: @@ -89,9 +86,6 @@ def forward(self, data, idx, target=None): return torch.gather(data, -1, idx) -gather_last_dim = FastGatherLastDim() - - class XSoftmax(torch.nn.Module): def __init__(self, dim): super().__init__() @@ -130,6 +124,7 @@ class IPUDisentangledSelfAttention(DisentangledSelfAttention): def __init__(self, config): super().__init__(config) self.xsoftmax = XSoftmax(-1) + self.gather_last_dim = FastGatherLastDim() def forward( self, @@ -252,7 +247,7 @@ def disentangled_att_bias(self, query_layer, key_layer, relative_pos, rel_embedd index = c2p_pos.expand( [query_layer.size(0), query_layer.size(1), query_layer.size(2), relative_pos.size(-1)] ) - c2p_att = gather_last_dim(c2p_att, index) + c2p_att = self.gather_last_dim(c2p_att, index) score += c2p_att # position->content @@ -272,7 +267,7 @@ def disentangled_att_bias(self, query_layer, key_layer, relative_pos, rel_embedd if query_layer.size(-2) != key_layer.size(-2): pos_index = relative_pos[:, :, :, 0].unsqueeze(-1) index = pos_index.expand(pos_index, p2c_att, key_layer) - p2c_att = gather_last_dim(p2c_att, index) + p2c_att = self.gather_last_dim(p2c_att, index) score += p2c_att return score @@ -286,7 +281,8 @@ def change_modules_for_ipu(self, restore: bool): if restore: del mod.xsoftmax else: - mod.xsoftmax = XSoftmax(-1) + mod.add_module("xsoftmax", XSoftmax(-1)) + mod.add_module("gather_last_dim", FastGatherLastDim()) if restore: if isinstance(mod, nn.Dropout): mod.__class__ = StableDropout @@ -316,13 +312,14 @@ def get_transformations(self): AddPoptorchBlock("Classifier Output", layer_ipu[-1], "classifier", log_insertions=log_insertions), AddPoptorchBlock("QA Outputs", layer_ipu[-1], "qa_outputs", log_insertions=log_insertions), ] - # if self.ipu_config.recompute_checkpoint_every_layer: - # transformations.append( - # RecomputationCheckpoint( - # "deberta.encoder.layer.[0-9]+", - # to_exclude=f"deberta.encoder.layer.{self.config.num_hidden_layers - 1}", - # ) - # ) + if self.ipu_config.recompute_checkpoint_every_layer: + transformations.append( + RecomputationCheckpoint( + "deberta.encoder.layer.[0-9]+", + to_exclude=f"deberta.encoder.layer.{self.config.num_hidden_layers - 1}", + output_nodes_specs={"call_function": [operator.add]}, + ) + ) if self.ipu_config.embedding_serialization_factor > 1: transformations.append(VocabEmbeddingToSerializedEmbedding()) return transformations From c2eed06d35631550a420c3d3a297c3582bac0c46 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Tue, 20 Sep 2022 12:23:44 +0200 Subject: [PATCH 15/33] Fix BartForSequenceClassification --- optimum/graphcore/models/bart/modeling_bart.py | 6 ++---- tests/test_examples.py | 6 ++++++ 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/optimum/graphcore/models/bart/modeling_bart.py b/optimum/graphcore/models/bart/modeling_bart.py index a906cc376..a9943fd19 100644 --- a/optimum/graphcore/models/bart/modeling_bart.py +++ b/optimum/graphcore/models/bart/modeling_bart.py @@ -32,6 +32,7 @@ ShareEmbeddingComputation, TieWeights, TupleOutput, + VocabEmbeddingToSerializedEmbedding, ) from ...fx.utils import symbolic_trace_pipelined_model from ...generation_utils import IPUGenerationMixin @@ -384,10 +385,7 @@ def get_transformations(self): if not isinstance(self, torch.fx.GraphModule): if self.ipu_config.embedding_serialization_factor > 1: - transformations += [ - LinearToSerializedLinear("lm_head"), - 
TieWeights("model.shared", "lm_head"), - ] + transformations.append(VocabEmbeddingToSerializedEmbedding()) transformations += [ShareEmbeddingComputation()] return transformations diff --git a/tests/test_examples.py b/tests/test_examples.py index 7f2278ba0..5427b4413 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -374,6 +374,12 @@ def _install_requirements(self, requirements_filename: Union[str, os.PathLike]): return_code = p.wait() self.assertEqual(return_code, 0) + # TODO: remove this. + cmd_line = f"{pip_name} install git+https://github.com/huggingface/optimum.git".split() + p = subprocess.Popen(cmd_line) + return_code = p.wait() + self.assertEqual(return_code, 0) + # Install requirements if not Path(requirements_filename).exists(): return From 2a5c7165a30a52eb17e1da4cb5da4dc3189947aa Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Tue, 11 Oct 2022 14:31:24 +0200 Subject: [PATCH 16/33] Fixes --- .../graphcore/models/bart/modeling_bart.py | 3 +- .../graphcore/models/bert/modeling_bert.py | 6 +-- .../models/convnext/modeling_convnext.py | 38 +------------------ .../models/deberta/modeling_deberta.py | 6 +-- .../models/lxmert/modeling_lxmert.py | 6 +-- .../models/wav2vec2/modeling_wav2vec2.py | 1 + optimum/graphcore/trainer.py | 2 +- 7 files changed, 14 insertions(+), 48 deletions(-) diff --git a/optimum/graphcore/models/bart/modeling_bart.py b/optimum/graphcore/models/bart/modeling_bart.py index a9943fd19..2358d1bb1 100644 --- a/optimum/graphcore/models/bart/modeling_bart.py +++ b/optimum/graphcore/models/bart/modeling_bart.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Tuple +from typing import List, Optional, Tuple, Union import torch from torch import nn @@ -20,6 +20,7 @@ from optimum.utils import logging from transformers import BartForConditionalGeneration, BartForSequenceClassification from transformers.models.bart.modeling_bart import BartAttention +from transformers.modeling_outputs import Seq2SeqLMOutput, Seq2SeqSequenceClassifierOutput from ....fx.optimization import ChangeTrueDivToMulByInverse, MergeLinears, ReversibleTransformation, compose from ...fx.transformations import ( diff --git a/optimum/graphcore/models/bert/modeling_bert.py b/optimum/graphcore/models/bert/modeling_bert.py index b5760d98b..2db8aee3a 100644 --- a/optimum/graphcore/models/bert/modeling_bert.py +++ b/optimum/graphcore/models/bert/modeling_bert.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import math - from typing import Optional, Tuple, Union import torch @@ -22,15 +20,17 @@ import poptorch from optimum.utils import logging from scipy.stats import truncnorm -from transformers import ( +from transformers.models.bert.modeling_bert import ( BertForMaskedLM, BertForMultipleChoice, BertForPreTraining, BertForQuestionAnswering, BertForSequenceClassification, BertForTokenClassification, + BertForPreTrainingOutput, ) from transformers.utils.fx import _gen_constructor_wrapper +from transformers.modeling_outputs import MaskedLMOutput, QuestionAnsweringModelOutput from ....fx.optimization import ChangeTrueDivToMulByInverse, MergeLinears, compose from ...fx.transformations import ( diff --git a/optimum/graphcore/models/convnext/modeling_convnext.py b/optimum/graphcore/models/convnext/modeling_convnext.py index d0a9983a2..9391bd295 100644 --- a/optimum/graphcore/models/convnext/modeling_convnext.py +++ b/optimum/graphcore/models/convnext/modeling_convnext.py @@ -12,8 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import torch -import poptorch -import transformers +from torch import nn from transformers.models.convnext.modeling_convnext import ConvNextLayer, ConvNextLayerNorm, ConvNextForImageClassification from ....fx.optimization import ChangeTrueDivToMulByInverse, MergeLinears, compose @@ -47,7 +46,6 @@ ] -<<<<<<< HEAD class IPUConvNextLayerNorm(nn.Module): r"""LayerNorm that supports two data formats: channels_last (default) or channels_first. The ordering of the dimensions in the inputs. channels_last corresponds to inputs with shape (batch_size, height, @@ -115,7 +113,6 @@ def parallelize(self): for layer in stage.layers: layer.__class__ = OptimizedConvNextLayer -<<<<<<< HEAD # Enable autocast for ConvNextLayerNorm because computation cannot happen in fp16 for mod in self.modules(): if isinstance(mod, ConvNextLayerNorm): @@ -148,36 +145,3 @@ def deparallelize(self): composition = compose(*transformations) self = composition(self, reverse=True) return self - # def parallelize(self): - # super().parallelize() - - # # Use optimized ConvNextLayer - # for stage in self.convnext.encoder.stages: - # for layer in stage.layers: - # layer.__class__ = OptimizedConvNextLayer - - # # Enable autocast for ConvNextLayerNorm because computation cannot happen in fp16 - # for mod in self.modules(): - # if isinstance(mod, ConvNextLayerNorm): - # mod.forward = poptorch.autocast(enabled=True)(mod.forward) - - # logger.info("---------- Device Allocation -----------") - # logger.info(f"Embedding --> IPU 0") - # self.convnext.embeddings = poptorch.BeginBlock(self.convnext.embeddings, "Embedding", ipu_id=0) - - # layer_ipu = get_layer_ipu(self.ipu_config.layers_per_ipu) - # global_layer_idx = 0 - # for stage_idx, stage in enumerate(self.convnext.encoder.stages): - # for layer_idx, layer in enumerate(stage.layers): - # ipu = layer_ipu[global_layer_idx] - # logger.info(f"Encoder stage {stage_idx}, convnext layer {layer_idx} --> IPU {ipu}") - # layer = poptorch.BeginBlock(layer, f"Encoder_stage_{stage_idx}_layer_{layer_idx}", ipu_id=ipu) - # global_layer_idx += 1 - - # last_ipu = self.ipu_config.ipus_per_replica - 1 - # logger.info(f"Head --> IPU {last_ipu}") - # logger.info("---------------------------------------") - # self.convnext.layernorm = poptorch.BeginBlock(self.convnext.layernorm, "LayerNorm", ipu_id=last_ipu) - # self.classifier = poptorch.BeginBlock(self.classifier, "Classifier", ipu_id=last_ipu) - - # return self diff --git 
a/optimum/graphcore/models/deberta/modeling_deberta.py b/optimum/graphcore/models/deberta/modeling_deberta.py index 318abdbd9..8770d14c4 100644 --- a/optimum/graphcore/models/deberta/modeling_deberta.py +++ b/optimum/graphcore/models/deberta/modeling_deberta.py @@ -14,18 +14,20 @@ import math import operator +from typing import Optional, Union, Tuple import torch import torch.nn as nn import poptorch -from transformers import DebertaForQuestionAnswering, DebertaForSequenceClassification, DebertaForTokenClassification +from transformers import DebertaForMaskedLM, DebertaForQuestionAnswering, DebertaForSequenceClassification, DebertaForTokenClassification from transformers.models.deberta.modeling_deberta import ( DebertaEncoder, DisentangledSelfAttention, StableDropout, build_relative_position, ) +from transformers.modeling_outputs import MaskedLMOutput, QuestionAnsweringModelOutput from ....fx.optimization import ChangeTrueDivToMulByInverse, MergeLinears, compose from ....utils import logging @@ -445,8 +447,6 @@ class PipelinedDebertaForSequenceClassification(DebertaForSequenceClassification ``` """ - # def forward(self, input_ids, attention_mask, token_type_ids, labels=None): - # return_dict = False @register(DebertaForTokenClassification) class PipelinedDebertaForTokenClassification(DebertaForTokenClassification, DebertaPipelineMixin): diff --git a/optimum/graphcore/models/lxmert/modeling_lxmert.py b/optimum/graphcore/models/lxmert/modeling_lxmert.py index 65c1cb4a0..ce4a5eab9 100644 --- a/optimum/graphcore/models/lxmert/modeling_lxmert.py +++ b/optimum/graphcore/models/lxmert/modeling_lxmert.py @@ -11,12 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Optional, Union, Tuple + import torch import torch.nn.functional as F -import poptorch -import transformers - +from transformers.models.lxmert.modeling_lxmert import LxmertForQuestionAnswering, LxmertForQuestionAnsweringOutput from ....fx.optimization import ChangeTrueDivToMulByInverse, MergeLinears, compose from ....utils import logging from ...fx.transformations import ( diff --git a/optimum/graphcore/models/wav2vec2/modeling_wav2vec2.py b/optimum/graphcore/models/wav2vec2/modeling_wav2vec2.py index 5111f2769..15a9b6546 100755 --- a/optimum/graphcore/models/wav2vec2/modeling_wav2vec2.py +++ b/optimum/graphcore/models/wav2vec2/modeling_wav2vec2.py @@ -27,6 +27,7 @@ Wav2Vec2EncoderStableLayerNorm, Wav2Vec2ForCTC, Wav2Vec2GumbelVectorQuantizer, + Wav2Vec2ForPreTrainingOutput, ) from ....fx.optimization import ChangeTrueDivToMulByInverse, MergeLinears, compose diff --git a/optimum/graphcore/trainer.py b/optimum/graphcore/trainer.py index 45ec3d6de..53a4b3381 100644 --- a/optimum/graphcore/trainer.py +++ b/optimum/graphcore/trainer.py @@ -909,7 +909,7 @@ def num_examples(self, dataloader: poptorch.DataLoader) -> int: """ return len(dataloader.dataset) -def wrap_model(self, model: Union[PreTrainedModel, PoplarExecutor], training: bool =True) -> PoplarExecutor: + def wrap_model(self, model: Union[PreTrainedModel, PoplarExecutor], training: bool = True) -> PoplarExecutor: """ Wraps a model for PopTorch, either for training or for inference. 
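For context on the `BartForSequenceClassification` fix in patch 15 above: with `embedding_serialization_factor > 1`, the vocabulary projection can be serialized either on the `lm_head` side (serialized linear plus explicit weight tying) or on the embedding side. A rough sketch of the two alternatives; the `apply_serialization` helper is illustrative, while the transformation arguments are the ones appearing in the diff:

```python
# Sketch: two ways to serialize the vocabulary projection for BART when
# embedding_serialization_factor > 1. The apply_serialization() helper is
# illustrative; the transformation arguments come from the diff above.
from optimum.fx.optimization import compose
from optimum.graphcore.fx.transformations import (
    LinearToSerializedLinear,
    TieWeights,
    VocabEmbeddingToSerializedEmbedding,
)

# Previous approach: serialize the lm_head matmul, then re-tie it to the shared embedding.
SERIALIZE_LM_HEAD = [
    LinearToSerializedLinear("lm_head"),
    TieWeights("model.shared", "lm_head"),
]

# Approach after patch 15: serialize the embedding lookup itself.
SERIALIZE_VOCAB_EMBEDDING = [VocabEmbeddingToSerializedEmbedding()]


def apply_serialization(traced, embedding_side=True):
    transformations = SERIALIZE_VOCAB_EMBEDDING if embedding_side else SERIALIZE_LM_HEAD
    return compose(*transformations)(traced)
```

The series switches BART to the embedding-side option, which drops the explicit `TieWeights` step after tracing.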
From aa4e93ee35b7c9ab00365bfb8fe065a64e48420e Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Tue, 11 Oct 2022 14:47:07 +0200 Subject: [PATCH 17/33] Fixes --- optimum/graphcore/trainer.py | 15 +++------------ 1 file changed, 3 insertions(+), 12 deletions(-) diff --git a/optimum/graphcore/trainer.py b/optimum/graphcore/trainer.py index 53a4b3381..d79d3cc3d 100644 --- a/optimum/graphcore/trainer.py +++ b/optimum/graphcore/trainer.py @@ -381,7 +381,7 @@ def __init__( if args.do_train: train_dl = self.get_train_dataloader() model = self.wrap_model(self.model) - self.compile_model(model, next(iter(train_dl)), log=True) + self.compile_model(self.training_model, next(iter(train_dl)), log=True) if args.do_eval: # Same thing with _wrap_and_compile_for_evaluation eval_dl = self.get_eval_dataloader() @@ -1091,21 +1091,13 @@ def _inner_training_loop( if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug: debug_overflow = DebugUnderflowOverflow(self.model) # noqa - # Activate gradient checkpointing if needed - if args.gradient_checkpointing: - self.model.gradient_checkpointing_enable() - - model = self._compile_model(next(iter(train_dataloader)), training=True, log=True) - self.create_optimizer_and_scheduler(num_training_steps=max_steps) - self.state = IPUTrainerState() if trial is not None: raise ValueError("Hyperparameter tuning is not supported by the IPUTrainer.") trial = None self.state.is_hyper_param_search = trial is not None - # TODO: brought by sdk3.0 pr self.training_model = self.wrap_model(self.model) # TODO: handle optimizer and scheduler creation @@ -1115,7 +1107,6 @@ def _inner_training_loop( # Check if saved optimizer or scheduler states exist self._load_optimizer_and_scheduler(resume_from_checkpoint) - # TODO: brought by sdk3.0 pr self.compile_model(self.training_model, next(iter(train_dataloader)), log=True) # Train! @@ -1779,7 +1770,7 @@ def evaluate( # Running this here (even though it is being recalled in self.evaluation_loop to make compilation happen here. # That way, compilation will not mess inference speed metrics. - _ = self._compile_model(next(iter(eval_dataloader)), training=False) + _ = self._wrap_and_compile_model_for_evaluation(eval_dataloader, prediction_loss_only) start_time = time.time() @@ -1900,7 +1891,7 @@ def evaluation_loop( prediction_loss_only if prediction_loss_only is not None else self.args.prediction_loss_only ) - model = self._compile_model(next(iter(dataloader)), training=False) + self.inference_model = self._wrap_and_compile_model_for_evaluation(dataloader, prediction_loss_only) batch_size = dataloader.batch_size From 7fd6bc992c7a25d2d426c790cfac8338d442bc9a Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Tue, 11 Oct 2022 15:38:54 +0200 Subject: [PATCH 18/33] Fixes --- optimum/graphcore/modeling_utils.py | 2 +- optimum/graphcore/trainer.py | 27 ++++++++++++--------------- 2 files changed, 13 insertions(+), 16 deletions(-) diff --git a/optimum/graphcore/modeling_utils.py b/optimum/graphcore/modeling_utils.py index e749eae2b..23934e2b8 100644 --- a/optimum/graphcore/modeling_utils.py +++ b/optimum/graphcore/modeling_utils.py @@ -446,6 +446,6 @@ def forward(self, sequence, positions): """ Gather the vectors at the specific positions over a batch. 
""" - num_classes = int(sequence.shape[1]) + num_classes = sequence.shape[1] one_hot_positions = F.one_hot(positions, num_classes).to(dtype=sequence.dtype) return torch.matmul(one_hot_positions.detach(), sequence) diff --git a/optimum/graphcore/trainer.py b/optimum/graphcore/trainer.py index d79d3cc3d..a76615f1a 100644 --- a/optimum/graphcore/trainer.py +++ b/optimum/graphcore/trainer.py @@ -380,12 +380,10 @@ def __init__( logger.info("Called with compile_only=True. Compiling models then exiting.") if args.do_train: train_dl = self.get_train_dataloader() - model = self.wrap_model(self.model) - self.compile_model(self.training_model, next(iter(train_dl)), log=True) + self.compile_model(next(iter(train_dl)), log=True) if args.do_eval: - # Same thing with _wrap_and_compile_for_evaluation eval_dl = self.get_eval_dataloader() - model = self._wrap_and_compile_model_for_evaluation(eval_dl, False) + self.compile_model(next(iter(eval_dl)), training=False) logger.info("Exiting after compiling models with compile_only=True") sys.exit(0) @@ -491,7 +489,7 @@ def compile_model( self.model = model else: self.model_for_eval = model - model = self._wrap_model(model, training=training) + model = self.wrap_model(model, training=training) if log: logger.info("Compiling Model...") start_compile = time.perf_counter() @@ -914,7 +912,7 @@ def wrap_model(self, model: Union[PreTrainedModel, PoplarExecutor], training: bo Wraps a model for PopTorch, either for training or for inference. Args: - model ([`~transformers.modeling_utils.PreTrainedModel`] or `poptorch.PoplarExecutor`): + model ([`transformers.PreTrainedModel`] or `poptorch.PoplarExecutor`): The model to wrap. training (`bool`, *optional*, defaults to `True`): Whether to wrap the model for training or not. @@ -1091,24 +1089,23 @@ def _inner_training_loop( if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug: debug_overflow = DebugUnderflowOverflow(self.model) # noqa - self.create_optimizer_and_scheduler(num_training_steps=max_steps) if trial is not None: raise ValueError("Hyperparameter tuning is not supported by the IPUTrainer.") trial = None self.state.is_hyper_param_search = trial is not None - self.training_model = self.wrap_model(self.model) + + # self.training_model = self.wrap_model(self.model) + self.traning_model = self.compile_model(next(iter(train_dataloader)), training=True) + + self.create_optimizer_and_scheduler(num_training_steps=max_steps) + self._load_optimizer_and_scheduler(resume_from_checkpoint) # TODO: handle optimizer and scheduler creation # if delay_optimizer_creation: # self.create_optimizer_and_scheduler(num_training_steps=max_steps) - # Check if saved optimizer or scheduler states exist - self._load_optimizer_and_scheduler(resume_from_checkpoint) - - self.compile_model(self.training_model, next(iter(train_dataloader)), log=True) - # Train! num_examples = ( self.num_examples(train_dataloader) @@ -1770,7 +1767,7 @@ def evaluate( # Running this here (even though it is being recalled in self.evaluation_loop to make compilation happen here. # That way, compilation will not mess inference speed metrics. 
- _ = self._wrap_and_compile_model_for_evaluation(eval_dataloader, prediction_loss_only) + _ = self.compile_model(next(iter(eval_dataloader)), training=False) start_time = time.time() @@ -1891,7 +1888,7 @@ def evaluation_loop( prediction_loss_only if prediction_loss_only is not None else self.args.prediction_loss_only ) - self.inference_model = self._wrap_and_compile_model_for_evaluation(dataloader, prediction_loss_only) + self.inference_model = self.compile_model(next(iter(dataloader)), training=False) batch_size = dataloader.batch_size From b8968623739457d280176b542716a16863ac529d Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Thu, 13 Oct 2022 18:43:29 +0200 Subject: [PATCH 19/33] Fix issues --- optimum/graphcore/fx/utils.py | 14 + .../graphcore/models/bart/modeling_bart.py | 192 +++++++------- .../models/deberta/modeling_deberta.py | 18 +- .../models/distilbert/modeling_distilbert.py | 246 ++++++++++-------- optimum/graphcore/trainer.py | 1 + 5 files changed, 259 insertions(+), 212 deletions(-) diff --git a/optimum/graphcore/fx/utils.py b/optimum/graphcore/fx/utils.py index 1588fae90..a8c53bcb0 100644 --- a/optimum/graphcore/fx/utils.py +++ b/optimum/graphcore/fx/utils.py @@ -24,6 +24,9 @@ class PipelinedTracer(HFTracer): + # TODO: keep this until transformers >= 4.23.2 + _TORCH_METHODS_TO_PATCH = list(HFTracer._TORCH_METHODS_TO_PATCH) + _TORCH_METHODS_TO_PATCH.append("clamp") """ Tracer that enables tracing and transforming models to run them on IPUs. Compared to the HFTracer, this one adds the following features: @@ -81,6 +84,17 @@ def call_module(self, m, forward, args, kwargs): self.current_module_type.pop(-1) return proxy + def create_proxy(self, kind, target, args, kwargs, name=None, type_expr=None, proxy_factory_fn=None): + # TODO: how to handle the case where the model is ran in full-precision? 
+ float32_dtype_in_args = any(a is torch.float32 for a in args) + float32_dtype_in_kwargs = kwargs.get("dtype", None) is torch.float32 + if kind == "call_method" and target == "to": + if float32_dtype_in_args: + args = tuple(a if a is not torch.float32 else torch.float16 for a in args) + if float32_dtype_in_kwargs: + kwargs["dtype"] = torch.float16 + return super().create_proxy(kind, target, args, kwargs, name=name, type_expr=type_expr, proxy_factory_fn=proxy_factory_fn) + def symbolic_trace_with_pipelined_tracer( model: PipelineMixin, diff --git a/optimum/graphcore/models/bart/modeling_bart.py b/optimum/graphcore/models/bart/modeling_bart.py index 2358d1bb1..6649468d6 100644 --- a/optimum/graphcore/models/bart/modeling_bart.py +++ b/optimum/graphcore/models/bart/modeling_bart.py @@ -435,99 +435,99 @@ def deparallelize(self): self = composition(self, reverse=True) return self - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - decoder_input_ids: Optional[torch.LongTensor] = None, - decoder_attention_mask: Optional[torch.LongTensor] = None, - head_mask: Optional[torch.Tensor] = None, - decoder_head_mask: Optional[torch.Tensor] = None, - cross_attn_head_mask: Optional[torch.Tensor] = None, - encoder_outputs: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - decoder_inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, Seq2SeqSequenceClassifierOutput]: - r""" - labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
- """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - if labels is not None: - use_cache = False - - outputs = self.model( - input_ids, - attention_mask=attention_mask, - decoder_input_ids=decoder_input_ids, - decoder_attention_mask=decoder_attention_mask, - head_mask=head_mask, - decoder_head_mask=decoder_head_mask, - cross_attn_head_mask=cross_attn_head_mask, - encoder_outputs=encoder_outputs, - inputs_embeds=inputs_embeds, - decoder_inputs_embeds=decoder_inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = outputs[0] # last hidden state - B, L, E = hidden_states.shape - - eos_mask = torch.eq(input_ids, self.config.eos_token_id) - # Static tensor shape version of hidden_states[eos_mask, :] - eos_indices = eos_mask * torch.arange(L).unsqueeze(0) - last_eos_index, _ = torch.max(eos_indices, dim=1) - # torch.index_select requires a 1D tensor of indices - last_eos_index += torch.arange(B) * L - hidden_states = hidden_states.view(B * L, E) - sentence_representation = torch.index_select(hidden_states, 0, last_eos_index) - - logits = self.classification_head(sentence_representation) - - loss = None - if labels is not None: - if self.config.problem_type is None: - if self.config.num_labels == 1: - self.config.problem_type = "regression" - elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): - self.config.problem_type = "single_label_classification" - else: - self.config.problem_type = "multi_label_classification" - - if self.config.problem_type == "regression": - loss_fct = nn.MSELoss() - if self.config.num_labels == 1: - loss = loss_fct(logits.squeeze(), labels.squeeze()) - else: - loss = loss_fct(logits, labels) - elif self.config.problem_type == "single_label_classification": - loss_fct = nn.CrossEntropyLoss() - loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1)) - elif self.config.problem_type == "multi_label_classification": - loss_fct = nn.BCEWithLogitsLoss() - loss = loss_fct(logits, labels) - - if not return_dict: - output = (logits,) + outputs[1:] - return ((loss,) + output) if loss is not None else output - - return Seq2SeqSequenceClassifierOutput( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - decoder_hidden_states=outputs.decoder_hidden_states, - decoder_attentions=outputs.decoder_attentions, - cross_attentions=outputs.cross_attentions, - encoder_last_hidden_state=outputs.encoder_last_hidden_state, - encoder_hidden_states=outputs.encoder_hidden_states, - encoder_attentions=outputs.encoder_attentions, - ) + # def forward( + # self, + # input_ids: torch.LongTensor = None, + # attention_mask: Optional[torch.Tensor] = None, + # decoder_input_ids: Optional[torch.LongTensor] = None, + # decoder_attention_mask: Optional[torch.LongTensor] = None, + # head_mask: Optional[torch.Tensor] = None, + # decoder_head_mask: Optional[torch.Tensor] = None, + # cross_attn_head_mask: Optional[torch.Tensor] = None, + # encoder_outputs: Optional[List[torch.FloatTensor]] = None, + # inputs_embeds: Optional[torch.FloatTensor] = None, + # decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + # labels: Optional[torch.LongTensor] = None, + # use_cache: Optional[bool] = None, + # output_attentions: Optional[bool] = None, + # output_hidden_states: Optional[bool] = None, + # return_dict: Optional[bool] = None, + # ) -> Union[Tuple, 
Seq2SeqSequenceClassifierOutput]: + # r""" + # labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + # Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + # config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + # """ + # return_dict = return_dict if return_dict is not None else self.config.use_return_dict + # if labels is not None: + # use_cache = False + + # outputs = self.model( + # input_ids, + # attention_mask=attention_mask, + # decoder_input_ids=decoder_input_ids, + # decoder_attention_mask=decoder_attention_mask, + # head_mask=head_mask, + # decoder_head_mask=decoder_head_mask, + # cross_attn_head_mask=cross_attn_head_mask, + # encoder_outputs=encoder_outputs, + # inputs_embeds=inputs_embeds, + # decoder_inputs_embeds=decoder_inputs_embeds, + # use_cache=use_cache, + # output_attentions=output_attentions, + # output_hidden_states=output_hidden_states, + # return_dict=return_dict, + # ) + + # hidden_states = outputs[0] # last hidden state + # B, L, E = hidden_states.shape + + # eos_mask = torch.eq(input_ids, self.config.eos_token_id) + # # Static tensor shape version of hidden_states[eos_mask, :] + # eos_indices = eos_mask * torch.arange(L).unsqueeze(0) + # last_eos_index, _ = torch.max(eos_indices, dim=1) + # # torch.index_select requires a 1D tensor of indices + # last_eos_index += torch.arange(B) * L + # hidden_states = hidden_states.view(B * L, E) + # sentence_representation = torch.index_select(hidden_states, 0, last_eos_index) + + # logits = self.classification_head(sentence_representation) + + # loss = None + # if labels is not None: + # if self.config.problem_type is None: + # if self.config.num_labels == 1: + # self.config.problem_type = "regression" + # elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + # self.config.problem_type = "single_label_classification" + # else: + # self.config.problem_type = "multi_label_classification" + + # if self.config.problem_type == "regression": + # loss_fct = nn.MSELoss() + # if self.config.num_labels == 1: + # loss = loss_fct(logits.squeeze(), labels.squeeze()) + # else: + # loss = loss_fct(logits, labels) + # elif self.config.problem_type == "single_label_classification": + # loss_fct = nn.CrossEntropyLoss() + # loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1)) + # elif self.config.problem_type == "multi_label_classification": + # loss_fct = nn.BCEWithLogitsLoss() + # loss = loss_fct(logits, labels) + + # if not return_dict: + # output = (logits,) + outputs[1:] + # return ((loss,) + output) if loss is not None else output + + # return Seq2SeqSequenceClassifierOutput( + # loss=loss, + # logits=logits, + # past_key_values=outputs.past_key_values, + # decoder_hidden_states=outputs.decoder_hidden_states, + # decoder_attentions=outputs.decoder_attentions, + # cross_attentions=outputs.cross_attentions, + # encoder_last_hidden_state=outputs.encoder_last_hidden_state, + # encoder_hidden_states=outputs.encoder_hidden_states, + # encoder_attentions=outputs.encoder_attentions, + # ) diff --git a/optimum/graphcore/models/deberta/modeling_deberta.py b/optimum/graphcore/models/deberta/modeling_deberta.py index 8770d14c4..23ac37e5e 100644 --- a/optimum/graphcore/models/deberta/modeling_deberta.py +++ b/optimum/graphcore/models/deberta/modeling_deberta.py @@ -20,6 +20,7 @@ import torch.nn as nn import poptorch +import transformers from transformers import 
DebertaForMaskedLM, DebertaForQuestionAnswering, DebertaForSequenceClassification, DebertaForTokenClassification from transformers.models.deberta.modeling_deberta import ( DebertaEncoder, @@ -28,6 +29,7 @@ build_relative_position, ) from transformers.modeling_outputs import MaskedLMOutput, QuestionAnsweringModelOutput +from transformers.utils.fx import _gen_constructor_wrapper from ....fx.optimization import ChangeTrueDivToMulByInverse, MergeLinears, compose from ....utils import logging @@ -43,7 +45,7 @@ VocabEmbeddingToSerializedEmbedding, ) from ...fx.utils import symbolic_trace_pipelined_model -from ...modeling_utils import PipelineMixin, get_layer_ipu, register +from ...modeling_utils import PipelineMixin, get_layer_ipu, register, OnehotGather logger = logging.get_logger(__name__) @@ -264,7 +266,7 @@ def disentangled_att_bias(self, query_layer, key_layer, relative_pos, rel_embedd p2c_pos = torch.clamp(-r_pos + att_span, 0, att_span * 2 - 1) index = p2c_pos.expand([query_layer.size(0), query_layer.size(1), key_layer.size(-2), key_layer.size(-2)]) p2c_att = torch.matmul(key_layer, pos_query_layer.transpose(-1, -2)) - p2c_att = gather_last_dim(p2c_att, index).transpose(-1, -2) + p2c_att = self.gather_last_dim(p2c_att, index).transpose(-1, -2) if query_layer.size(-2) != key_layer.size(-2): pos_index = relative_pos[:, :, :, 0].unsqueeze(-1) @@ -335,7 +337,9 @@ def parallelize(self): """ super().parallelize() self.change_modules_for_ipu(False) + torch.nn.functional.one_hot, orig = _gen_constructor_wrapper(torch.nn.functional.one_hot) traced = symbolic_trace_pipelined_model(self) + torch.nn.functional.one_hot = orig transformations = self.get_transformations() transformations += _OPTIMIZATION_TRANSFORMATIONS composition = compose(*transformations, inplace=True) @@ -408,10 +412,10 @@ def forward( sequence_output = outputs[0] if labels is not None: - # Select only the masked tokens for the classifier - max_number_of_masked_tokens = int(labels.size(1) * 0.25) - masked_lm_labels, masked_lm_positions = torch.topk(labels, k=max_number_of_masked_tokens, dim=1) - masked_output = self.gather_indices(sequence_output, masked_lm_positions) + if hasattr(self.config, "max_num_masked_tokens"): + # Select only the masked tokens for the classifier + masked_lm_labels, positions = torch.topk(labels, k=self.config.max_num_masked_tokens, dim=1) + masked_output = self.gather_indices(sequence_output, positions) else: # This case should never happen during training masked_output = sequence_output @@ -420,7 +424,7 @@ def forward( masked_lm_loss = None if labels is not None: - masked_lm_loss = F.cross_entropy( + masked_lm_loss = nn.functional.cross_entropy( prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1) ).float() diff --git a/optimum/graphcore/models/distilbert/modeling_distilbert.py b/optimum/graphcore/models/distilbert/modeling_distilbert.py index 92afbe35f..25213403e 100644 --- a/optimum/graphcore/models/distilbert/modeling_distilbert.py +++ b/optimum/graphcore/models/distilbert/modeling_distilbert.py @@ -20,6 +20,24 @@ import torch.nn.functional as F import poptorch + +from transformers.modeling_outputs import MaskedLMOutput, QuestionAnsweringModelOutput +from transformers.models.distilbert.modeling_distilbert import MultiHeadSelfAttention +from ....fx.optimization import ChangeTrueDivToMulByInverse, MergeLinears, compose +from ...fx.transformations import ( + AddPoptorchBlock, + AddPoptorchBlocksInSeries, + ClipValues, + ClipValuesSymmetric, + LinearToSerializedLinear, + 
OutlineAttribute, + RecomputationCheckpoint, + TieWeights, + TupleOutput, + VocabEmbeddingToSerializedEmbedding, +) +from ...fx.utils import symbolic_trace_pipelined_model +from ...modeling_utils import OnehotGather, PipelineMixin, get_layer_ipu, register from optimum.utils import logging from transformers import ( DistilBertForMaskedLM, @@ -28,23 +46,24 @@ DistilBertForSequenceClassification, DistilBertForTokenClassification, ) -from transformers.modeling_outputs import MaskedLMOutput, QuestionAnsweringModelOutput -from transformers.models.distilbert.modeling_distilbert import MultiHeadSelfAttention - -from ...modeling_utils import ( - OnehotGather, - PipelineMixin, - SerializedEmbedding, - SerializedLinear, - get_layer_ipu, - recomputation_checkpoint, - register, -) logger = logging.get_logger(__name__) +_OPTIMIZATION_TRANSFORMATIONS = [ + ChangeTrueDivToMulByInverse(), + MergeLinears(), + # FuseBiasInLinear(), +] +_NON_REVERSIBLE_TRANSFORMATIONS = [ + ClipValuesSymmetric(1e4, exclude_targets=["view"]), + ClipValues(1e-4, float("inf"), include_targets=[torch.nn.LayerNorm]), + TupleOutput(), +] + + +# TODO: should we make a fx transformation for this? class IPUMultiHeadSelfAttention(MultiHeadSelfAttention): def forward( self, @@ -89,13 +108,12 @@ def unshape(x: torch.Tensor) -> torch.Tensor: q = q / math.sqrt(dim_per_head) # (bs, n_heads, q_length, dim_per_head) scores = torch.matmul(q, k.transpose(2, 3)) # (bs, n_heads, q_length, k_length) - mask = mask.to(dtype=scores.dtype) # fp16 compatibility # Since attention_mask is 1.0 for positions we want to attend and 0.0 for - # masked positions, this operation will create a tensor which is 0.0 for - # positions we want to attend and -10000.0 for masked positions. + # masked positions, this operation will create a tensor which is 0 for + # positions we want to attend and -10000 for masked positions. # Since we are adding it to the raw scores before the softmax, this is # effectively the same as removing these entirely. - mask = (1.0 - mask) * -10000.0 + mask = (1 - mask) * -10000 mask = mask.view(mask_reshp).expand_as(scores) # (bs, n_heads, q_length, k_length) scores = scores + mask # (bs, n_heads, q_length, k_length) @@ -117,35 +135,50 @@ def unshape(x: torch.Tensor) -> torch.Tensor: class DistilBertPipelineMixin(PipelineMixin): + def get_transformations(self): + log_insertions = self.ipu_config.log_insertions + layer_ipu = get_layer_ipu(self.ipu_config.layers_per_ipu) + last_ipu = len(self.ipu_config.layers_per_ipu) - 1 + transformations = [ + AddPoptorchBlock("Embedding", 0, "distilbert.embeddings", log_insertions=log_insertions), + OutlineAttribute("distilbert.embeddings.LayerNorm", "Embedding"), + AddPoptorchBlocksInSeries( + "Encoder", layer_ipu, r"distilbert.encoder.layer.[0-9]+", log_insertions=log_insertions + ), + # Only one of the following AddPoptorchBlock, will actually add a block. + AddPoptorchBlock("Classifier Output", last_ipu, "classifier", log_insertions=log_insertions), + AddPoptorchBlock("QA Outputs", last_ipu, "qa_outputs", log_insertions=log_insertions), + ] + if self.ipu_config.recompute_checkpoint_every_layer: + transformations.append( + RecomputationCheckpoint( + "distilbert.encoder.layer.[0-9]+", to_exclude=f"distilbert.encoder.layer.{self.config.num_hidden_layers - 1}" + ) + ) + if self.ipu_config.embedding_serialization_factor > 1: + transformations.append(VocabEmbeddingToSerializedEmbedding()) + return transformations + def parallelize(self): """ Transform the model to run in an IPU pipeline. 
- Adds pipeline stages to the model + - Replaces self-attention layers with fused-qkv self-attention layers + - (If enabled) Replaces the word embedding with a SerializedEmbedding - Adds recomputation checkpoints """ super().parallelize() - - for layer in self.distilbert.transformer.layer: - layer.attention.__class__ = IPUMultiHeadSelfAttention - - logger.info("-------------------- Device Allocation --------------------") - logger.info("Embedding --> IPU 0") - is_masked_lm = isinstance(self, DistilBertForMaskedLM) - if self.ipu_config.embedding_serialization_factor > 1 and not is_masked_lm: - self.distilbert.embeddings.word_embeddings = SerializedEmbedding( - self.distilbert.embeddings.word_embeddings, self.ipu_config.embedding_serialization_factor - ) - self.distilbert.embeddings = poptorch.BeginBlock(self.distilbert.embeddings, "Embedding", 0) - - layer_ipu = get_layer_ipu(self.ipu_config.layers_per_ipu) - for index, layer in enumerate(self.distilbert.transformer.layer): - ipu = layer_ipu[index] - if self.ipu_config.recompute_checkpoint_every_layer and index != self.config.num_hidden_layers - 1: - recomputation_checkpoint(layer) - self.distilbert.transformer.layer[index] = poptorch.BeginBlock(layer, f"Encoder{index}", ipu_id=ipu) - logger.info(f"Encoder {index:<2} --> IPU {ipu}") - - return self + for mod in self.modules(): + if isinstance(mod, MultiHeadSelfAttention): + mod.__class__ = IPUMultiHeadSelfAttention + traced = symbolic_trace_pipelined_model(self) + transformations = self.get_transformations() + transformations += _OPTIMIZATION_TRANSFORMATIONS + composition = compose(*transformations) + non_reversible_composition = compose(*_NON_REVERSIBLE_TRANSFORMATIONS) + traced = composition(traced) + traced = non_reversible_composition(traced) + return traced def deparallelize(self): """ @@ -154,14 +187,13 @@ def deparallelize(self): compatible with the original model. """ super().deparallelize() - - for layer in self.distilbert.transformer.layer: - layer.attention.__class__ = MultiHeadSelfAttention - - is_masked_lm = isinstance(self, DistilBertForMaskedLM) - if self.ipu_config.embedding_serialization_factor > 1 and not is_masked_lm: - self.distilbert.embeddings.word_embeddings = self.distilbert.embeddings.word_embeddings.deserialize() - + for mod in self.modules(): + if isinstance(mod, IPUMultiHeadSelfAttention): + mod.__class__ = MultiHeadSelfAttention + transformations = self.get_transformations() + transformations += _OPTIMIZATION_TRANSFORMATIONS + composition = compose(*transformations) + self = composition(self, reverse=True) return self @@ -171,40 +203,67 @@ def __init__(self, config): super().__init__(config) self.gather_indices = OnehotGather() + # TODO: validate that. 
+ def get_transformations(self): + log_insertions = self.ipu_config.log_insertions + layer_ipu = get_layer_ipu(self.ipu_config.layers_per_ipu) + transformations = [ + AddPoptorchBlock("Embedding", 0, "distilbert.embeddings", log_insertions=log_insertions), + OutlineAttribute("distilbert.embeddings.LayerNorm", "Embedding"), + AddPoptorchBlocksInSeries( + "Encoder", layer_ipu, r"distilbert.encoder.layer.[0-9]+", log_insertions=log_insertions + ), + AddPoptorchBlock("Classifier Output", 0, "cls", log_insertions=log_insertions), + ] + if self.ipu_config.recompute_checkpoint_every_layer: + transformations.append( + RecomputationCheckpoint( + "distilbert.encoder.layer.[0-9]+", to_exclude=f"distilbert.encoder.layer.{self.config.num_hidden_layers - 1}" + ) + ) + if self.ipu_config.embedding_serialization_factor > 1: + transformations += [ + LinearToSerializedLinear("cls.predictions.decoder"), + TieWeights("distilbert.embeddings.word_embeddings", "cls.predictions.decoder"), + ] + return transformations + def parallelize(self): + """ + Transform the model to run in an IPU pipeline. + - Adds pipeline stages to the model + - Replaces self-attention layers with fused-qkv self-attention layers + - (If enabled) Replaces the word embedding projection with a SerializedLinear layer + - Adds recomputation checkpoints + """ super().parallelize() - - if self.ipu_config.embedding_serialization_factor > 1: - serialized_vocab_projector = SerializedLinear( - self.config.dim, - self.config.vocab_size, - self.ipu_config.embedding_serialization_factor, - bias=True, - mode=poptorch.MatMulSerializationMode.OutputChannels, - ) - serialized_vocab_projector.load_state_dict(self.vocab_projector.state_dict()) - self.vocab_projector = serialized_vocab_projector - self.tie_weights() - - logger.info("LM Head --> IPU 0") - self.vocab_transform = poptorch.BeginBlock(self.vocab_transform, "LM Head", ipu_id=0) - self.vocab_layer_norm = poptorch.BeginBlock(self.vocab_layer_norm, "LM Head", ipu_id=0) - self.vocab_projector = poptorch.BeginBlock(self.vocab_projector, "LM Head", ipu_id=0) - logger.info("-----------------------------------------------------------") - return self + for mod in self.modules(): + if isinstance(mod, MultiHeadSelfAttention): + mod.__class__ = IPUMultiHeadSelfAttention + traced = symbolic_trace_pipelined_model(self) + transformations = self.get_transformations() + transformations += _OPTIMIZATION_TRANSFORMATIONS + composition = compose(*transformations) + non_reversible_composition = compose(*_NON_REVERSIBLE_TRANSFORMATIONS) + traced = composition(traced) + traced = non_reversible_composition(traced) + return traced def deparallelize(self): + """ + Undo the changes to the model done by `parallelize`. + You should call this before doing `save_pretrained` so that the `model.state_dict` is + compatible with the original model. 
+ """ super().deparallelize() - - if self.ipu_config.embedding_serialization_factor > 1: - vocab_projector = nn.Linear( - self.config.hidden_size, - self.config.vocab_size, - bias=True, - ) - vocab_projector.load_state_dict(self.vocab_projector.state_dict()) - self.vocab_projector = vocab_projector - self.tie_weights() + transformations = self.get_transformations() + transformations += _OPTIMIZATION_TRANSFORMATIONS + composition = compose(*transformations) + self = composition(self, reverse=True) + for mod in self.modules(): + if isinstance(mod, IPUMultiHeadSelfAttention): + mod.__class__ = MultiHeadSelfAttention + return self def forward( self, @@ -268,27 +327,11 @@ def forward( @register(DistilBertForSequenceClassification) class PipelinedDistilBertForSequenceClassification(DistilBertForSequenceClassification, DistilBertPipelineMixin): - def parallelize(self): - super().parallelize() - - last_ipu = self.ipu_config.ipus_per_replica - 1 - logger.info(f"Classifier --> IPU {last_ipu}") - self.pre_classifier = poptorch.BeginBlock(self.pre_classifier, "Classifier", ipu_id=last_ipu) - self.classifier = poptorch.BeginBlock(self.classifier, "Classifier", ipu_id=last_ipu) - logger.info("-----------------------------------------------------------") - return self + pass @register(DistilBertForQuestionAnswering) class PipelinedDistilBertForQuestionAnswering(DistilBertForQuestionAnswering, DistilBertPipelineMixin): - def parallelize(self): - super().parallelize() - - last_ipu = self.ipu_config.ipus_per_replica - 1 - logger.info(f"QA Outputs --> IPU {last_ipu}") - self.qa_outputs = poptorch.BeginBlock(self.qa_outputs, "QA Outputs", ipu_id=last_ipu) - logger.info("-----------------------------------------------------------") - return self def forward( self, @@ -332,24 +375,9 @@ def forward( @register(DistilBertForTokenClassification) class PipelinedDistilBertForTokenClassification(DistilBertForTokenClassification, DistilBertPipelineMixin): - def parallelize(self): - super().parallelize() - - last_ipu = self.ipu_config.ipus_per_replica - 1 - logger.info(f"Classifier --> IPU {last_ipu}") - self.classifier = poptorch.BeginBlock(self.classifier, "Classifier", ipu_id=last_ipu) - logger.info("-----------------------------------------------------------") - return self + pass @register(DistilBertForMultipleChoice) class PipelinedDistilBertForMultipleChoice(DistilBertForMultipleChoice, DistilBertPipelineMixin): - def parallelize(self): - super().parallelize() - - last_ipu = self.ipu_config.ipus_per_replica - 1 - logger.info(f"Classifier --> IPU {last_ipu}") - self.pre_classifier = poptorch.BeginBlock(self.pre_classifier, "Classifier", ipu_id=last_ipu) - self.classifier = poptorch.BeginBlock(self.classifier, "Classifier", ipu_id=last_ipu) - logger.info("-----------------------------------------------------------") - return self + pass diff --git a/optimum/graphcore/trainer.py b/optimum/graphcore/trainer.py index a76615f1a..0ecd5c090 100644 --- a/optimum/graphcore/trainer.py +++ b/optimum/graphcore/trainer.py @@ -485,6 +485,7 @@ def compile_model( model = model.parallelize() if not self.args.fp32: model.half() + import pdb; pdb.set_trace() if training: self.model = model else: From ddfd342d45158c9beed3e69e52ba4ec8938743be Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Tue, 18 Oct 2022 18:28:23 +0200 Subject: [PATCH 20/33] [WIP] tests/test_pipelined_models.py --- optimum/graphcore/fx/transformations.py | 19 +- optimum/graphcore/fx/utils.py | 57 ++++- .../graphcore/models/bart/modeling_bart.py | 200 
+++++++++--------- .../graphcore/models/bert/modeling_bert.py | 5 +- .../models/convnext/modeling_convnext.py | 34 ++- .../models/deberta/modeling_deberta.py | 16 +- .../models/distilbert/modeling_distilbert.py | 28 +-- .../graphcore/models/gpt2/modeling_gpt2.py | 1 - .../models/hubert/modeling_hubert.py | 1 - .../models/lxmert/modeling_lxmert.py | 4 +- .../models/roberta/modeling_roberta.py | 1 - optimum/graphcore/models/t5/modeling_t5.py | 151 +++++++++---- optimum/graphcore/models/vit/modeling_vit.py | 1 - .../models/wav2vec2/modeling_wav2vec2.py | 27 ++- optimum/graphcore/trainer.py | 6 +- tests/test_pipelined_models.py | 63 +++--- tests/test_trainer.py | 35 ++- 17 files changed, 390 insertions(+), 259 deletions(-) diff --git a/optimum/graphcore/fx/transformations.py b/optimum/graphcore/fx/transformations.py index 837e91cf3..3c034ef94 100644 --- a/optimum/graphcore/fx/transformations.py +++ b/optimum/graphcore/fx/transformations.py @@ -47,7 +47,6 @@ def parent_module_type(node: "Node") -> Union[str, Type]: return getattr(node, "parent_module_type", None) - class AddPoptorchBlockBase(ReversibleTransformation): """ Base class that provide useful methods for inserting poptorch blocks in the model. @@ -254,7 +253,9 @@ class RecomputationCheckpoint(ReversibleTransformation): Annotates the output of a module to be checkpointed instead of recomputed. """ - def __init__(self, name_regex: str, to_exclude: Optional[str] = None, output_nodes_specs: Dict[str, List[Any]] = None): + def __init__( + self, name_regex: str, to_exclude: Optional[str] = None, output_nodes_specs: Dict[str, List[Any]] = None + ): self.name_regex = re.compile(name_regex) self.to_exclude = re.compile(to_exclude) if to_exclude is not None else None self.output_nodes_specs = None @@ -292,7 +293,6 @@ def find_output_nodes_for_module_name(self, graph_module: "GraphModule", module_ output_nodes = [n for n in output_nodes if n.target in self.output_nodes_specs[n.op]] return output_nodes - def transform(self, graph_module: "GraphModule") -> "GraphModule": matched_module_names = collections.OrderedDict() for node in graph_module.graph.nodes: @@ -502,12 +502,18 @@ def transform(self, graph_module: "GraphModule") -> "GraphModule": class ShareEmbeddingComputation(Transformation): - def _find_nodes_to_move(self, graph_module, embedding_input_node): + def _find_nodes_to_move(self, graph_module, embedding_input_node, shared_embedding_node): + nodes_before_embedding_input_node = set() + for node in graph_module.graph.nodes: + if node is shared_embedding_node: + break + nodes_before_embedding_input_node.add(node) + to_visit = [embedding_input_node] to_move = set() while to_visit: node = to_visit.pop(0) - if node.op != "placeholder": + if node not in nodes_before_embedding_input_node: to_move.add(node) to_visit += node.all_input_nodes ordered_to_move = [] @@ -539,7 +545,7 @@ def transform(self, graph_module: "GraphModule") -> "GraphModule": raise NotImplementedError("Currently support embedding computation sharing for 2.") new_input_nodes = [] for input_node in reversed(embedding_input_nodes[1:]): - nodes_to_move = self._find_nodes_to_move(graph_module, input_node) + nodes_to_move = self._find_nodes_to_move(graph_module, input_node, embedding_nodes[target][0]) old_to_new_mapping = self._move_nodes_after_node(graph_module, nodes_to_move, embedding_input_nodes[0]) for old_node, new_node in old_to_new_mapping.items(): old_node.replace_all_uses_with(new_node) @@ -559,5 +565,4 @@ def transform(self, graph_module: "GraphModule") -> "GraphModule": 
getitem = graph_module.graph.call_function(operator.getitem, (shared_node, idx + 1)) embedding_node.replace_all_uses_with(getitem) graph_module.graph.erase_node(embedding_node) - return graph_module diff --git a/optimum/graphcore/fx/utils.py b/optimum/graphcore/fx/utils.py index a8c53bcb0..8d5d9f0e4 100644 --- a/optimum/graphcore/fx/utils.py +++ b/optimum/graphcore/fx/utils.py @@ -12,17 +12,27 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import inspect import math -from typing import Callable, List, Optional +from typing import TYPE_CHECKING, Callable, Dict, List, Optional import torch import transformers +from transformers.models.auto import get_values +from transformers.models.auto.modeling_auto import ( + MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES, + MODEL_FOR_CTC_MAPPING_NAMES, +) from transformers.utils.fx import HFTracer, get_concrete_args from ..modeling_utils import PipelineMixin +if TYPE_CHECKING: + from transformers import PreTrainedModel + + class PipelinedTracer(HFTracer): # TODO: keep this until transformers >= 4.23.2 _TORCH_METHODS_TO_PATCH = list(HFTracer._TORCH_METHODS_TO_PATCH) @@ -88,12 +98,35 @@ def create_proxy(self, kind, target, args, kwargs, name=None, type_expr=None, pr # TODO: how to handle the case where the model is ran in full-precision? float32_dtype_in_args = any(a is torch.float32 for a in args) float32_dtype_in_kwargs = kwargs.get("dtype", None) is torch.float32 - if kind == "call_method" and target == "to": - if float32_dtype_in_args: - args = tuple(a if a is not torch.float32 else torch.float16 for a in args) - if float32_dtype_in_kwargs: - kwargs["dtype"] = torch.float16 - return super().create_proxy(kind, target, args, kwargs, name=name, type_expr=type_expr, proxy_factory_fn=proxy_factory_fn) + node_types_to_inspect = [ + ("call_method", "to"), + ("call_function", torch.full), + ] + torch_methods_to_patched_version = {orig: wrapped for (orig, wrapped) in self.patched_torch_methods.values()} + for (k, t) in node_types_to_inspect: + if kind == k and target == torch_methods_to_patched_version.get(t, t): + if float32_dtype_in_args: + args = tuple(a if a is not torch.float32 else torch.float16 for a in args) + if float32_dtype_in_kwargs: + kwargs["dtype"] = torch.float16 + return super().create_proxy( + kind, target, args, kwargs, name=name, type_expr=type_expr, proxy_factory_fn=proxy_factory_fn + ) + + # TODO: keep until transformers 4.23.2 is released. 
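Note on the create_proxy override above: it rewrites constant torch.float32 dtype arguments to torch.float16 at trace time, so a model cast with model.half() does not reintroduce fp32 tensors through hard-coded `.to(torch.float32)` or `torch.full(..., dtype=torch.float32)` calls. Below is a minimal sketch of that idea on a plain torch.fx.Tracer; it only covers the `.to` method and ignores the patched torch functions and dummy-input generation handled by PipelinedTracer.

import torch
from torch import fx, nn


class HalfPrecisionDtypeTracer(fx.Tracer):
    # Rewrites constant float32 dtype arguments to float16 while tracing.
    def create_proxy(self, kind, target, args, kwargs, name=None, type_expr=None, proxy_factory_fn=None):
        if kind == "call_method" and target == "to":
            args = tuple(torch.float16 if a is torch.float32 else a for a in args)
            if kwargs.get("dtype") is torch.float32:
                kwargs = {**kwargs, "dtype": torch.float16}
        return super().create_proxy(
            kind, target, args, kwargs, name=name, type_expr=type_expr, proxy_factory_fn=proxy_factory_fn
        )


class TinyModel(nn.Module):
    def forward(self, x):
        return x.to(torch.float32) + 1.0


model = TinyModel()
graph = HalfPrecisionDtypeTracer().trace(model)
gm = fx.GraphModule(model, graph)
print(gm(torch.ones(2, dtype=torch.float16)).dtype)  # torch.float16 instead of torch.float32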
+ def _generate_dummy_input( + self, model: "PreTrainedModel", input_name: str, shape: List[int] + ) -> Dict[str, torch.Tensor]: + input_dict = {} + model_class_name = getattr(model, "class_for_deserialization", model.__class__).__name__ + if input_name == "labels": + if model_class_name in get_values(MODEL_FOR_CTC_MAPPING_NAMES): + input_dict["labels"] = torch.zeros(*shape, dtype=torch.float, device=model.device) + if model_class_name in get_values(MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES): + input_dict["labels"] = torch.zeros(shape[0], dtype=torch.long, device=model.device) + else: + input_dict = super()._generate_dummy_input(model, input_name, shape) + return input_dict def symbolic_trace_with_pipelined_tracer( @@ -139,13 +172,18 @@ def symbolic_trace_with_pipelined_tracer( return traced +def cast_traced_model_to_proper_class(model: torch.nn.Module, traced: torch.fx.GraphModule): + type_ = type(f"Traced{model.__class__.__name__}", (torch.fx.GraphModule, model.__class__), {}) + traced.__class__ = type_ + traced.recompile() + + def symbolic_trace_pipelined_model(pipelined_model: PipelineMixin) -> PipelineMixin: if isinstance(pipelined_model, torch.fx.GraphModule): return pipelined_model transformers_class = None bases = list(pipelined_model.__class__.__bases__) - import inspect while bases: base = bases.pop(0) @@ -160,6 +198,5 @@ def symbolic_trace_pipelined_model(pipelined_model: PipelineMixin) -> PipelineMi traced = symbolic_trace_with_pipelined_tracer( pipelined_model, input_names=pipelined_model.input_names_for_symbolic_trace ) - type_ = type(f"Traced{pipelined_model.__class__.__name__}", (torch.fx.GraphModule, pipelined_model.__class__), {}) - traced.__class__ = type_ + cast_traced_model_to_proper_class(pipelined_model, traced) return traced diff --git a/optimum/graphcore/models/bart/modeling_bart.py b/optimum/graphcore/models/bart/modeling_bart.py index 6649468d6..8f8c32ca8 100644 --- a/optimum/graphcore/models/bart/modeling_bart.py +++ b/optimum/graphcore/models/bart/modeling_bart.py @@ -19,8 +19,8 @@ import transformers from optimum.utils import logging from transformers import BartForConditionalGeneration, BartForSequenceClassification -from transformers.models.bart.modeling_bart import BartAttention from transformers.modeling_outputs import Seq2SeqLMOutput, Seq2SeqSequenceClassifierOutput +from transformers.models.bart.modeling_bart import BartAttention from ....fx.optimization import ChangeTrueDivToMulByInverse, MergeLinears, ReversibleTransformation, compose from ...fx.transformations import ( @@ -51,9 +51,8 @@ ] _NON_REVERSIBLE_TRANSFORMATIONS = [ - ClipValuesSymmetric(1e4, exclude_targets=["view"]), + ClipValuesSymmetric(10000, exclude_targets=["view"]), ClipValues(1e-4, float("inf"), include_targets=[torch.nn.LayerNorm]), - TupleOutput(), ] @@ -80,9 +79,10 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] tgt_len = tgt_len if tgt_len is not None else src_len expanded_mask = mask[:, None, None, :] - inverted_mask = 1.0 - expanded_mask + inverted_mask = 1 - expanded_mask inverted_mask = -float("inf") * inverted_mask + # inverted_mask = * inverted_mask return inverted_mask @@ -435,99 +435,99 @@ def deparallelize(self): self = composition(self, reverse=True) return self - # def forward( - # self, - # input_ids: torch.LongTensor = None, - # attention_mask: Optional[torch.Tensor] = None, - # decoder_input_ids: Optional[torch.LongTensor] = None, - # decoder_attention_mask: Optional[torch.LongTensor] = None, - # head_mask: Optional[torch.Tensor] 
= None, - # decoder_head_mask: Optional[torch.Tensor] = None, - # cross_attn_head_mask: Optional[torch.Tensor] = None, - # encoder_outputs: Optional[List[torch.FloatTensor]] = None, - # inputs_embeds: Optional[torch.FloatTensor] = None, - # decoder_inputs_embeds: Optional[torch.FloatTensor] = None, - # labels: Optional[torch.LongTensor] = None, - # use_cache: Optional[bool] = None, - # output_attentions: Optional[bool] = None, - # output_hidden_states: Optional[bool] = None, - # return_dict: Optional[bool] = None, - # ) -> Union[Tuple, Seq2SeqSequenceClassifierOutput]: - # r""" - # labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - # Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., - # config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). - # """ - # return_dict = return_dict if return_dict is not None else self.config.use_return_dict - # if labels is not None: - # use_cache = False - - # outputs = self.model( - # input_ids, - # attention_mask=attention_mask, - # decoder_input_ids=decoder_input_ids, - # decoder_attention_mask=decoder_attention_mask, - # head_mask=head_mask, - # decoder_head_mask=decoder_head_mask, - # cross_attn_head_mask=cross_attn_head_mask, - # encoder_outputs=encoder_outputs, - # inputs_embeds=inputs_embeds, - # decoder_inputs_embeds=decoder_inputs_embeds, - # use_cache=use_cache, - # output_attentions=output_attentions, - # output_hidden_states=output_hidden_states, - # return_dict=return_dict, - # ) - - # hidden_states = outputs[0] # last hidden state - # B, L, E = hidden_states.shape - - # eos_mask = torch.eq(input_ids, self.config.eos_token_id) - # # Static tensor shape version of hidden_states[eos_mask, :] - # eos_indices = eos_mask * torch.arange(L).unsqueeze(0) - # last_eos_index, _ = torch.max(eos_indices, dim=1) - # # torch.index_select requires a 1D tensor of indices - # last_eos_index += torch.arange(B) * L - # hidden_states = hidden_states.view(B * L, E) - # sentence_representation = torch.index_select(hidden_states, 0, last_eos_index) - - # logits = self.classification_head(sentence_representation) - - # loss = None - # if labels is not None: - # if self.config.problem_type is None: - # if self.config.num_labels == 1: - # self.config.problem_type = "regression" - # elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): - # self.config.problem_type = "single_label_classification" - # else: - # self.config.problem_type = "multi_label_classification" - - # if self.config.problem_type == "regression": - # loss_fct = nn.MSELoss() - # if self.config.num_labels == 1: - # loss = loss_fct(logits.squeeze(), labels.squeeze()) - # else: - # loss = loss_fct(logits, labels) - # elif self.config.problem_type == "single_label_classification": - # loss_fct = nn.CrossEntropyLoss() - # loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1)) - # elif self.config.problem_type == "multi_label_classification": - # loss_fct = nn.BCEWithLogitsLoss() - # loss = loss_fct(logits, labels) - - # if not return_dict: - # output = (logits,) + outputs[1:] - # return ((loss,) + output) if loss is not None else output - - # return Seq2SeqSequenceClassifierOutput( - # loss=loss, - # logits=logits, - # past_key_values=outputs.past_key_values, - # decoder_hidden_states=outputs.decoder_hidden_states, - # decoder_attentions=outputs.decoder_attentions, - # cross_attentions=outputs.cross_attentions, - # 
encoder_last_hidden_state=outputs.encoder_last_hidden_state, - # encoder_hidden_states=outputs.encoder_hidden_states, - # encoder_attentions=outputs.encoder_attentions, - # ) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + decoder_head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, Seq2SeqSequenceClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if labels is not None: + use_cache = False + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + encoder_outputs=encoder_outputs, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] # last hidden state + B, L, E = hidden_states.shape + + eos_mask = torch.eq(input_ids, self.config.eos_token_id) + # Static tensor shape version of hidden_states[eos_mask, :] + eos_indices = eos_mask * torch.arange(L).unsqueeze(0) + last_eos_index, _ = torch.max(eos_indices, dim=1) + # torch.index_select requires a 1D tensor of indices + last_eos_index += torch.arange(B) * L + hidden_states = hidden_states.view(B * L, E) + sentence_representation = torch.index_select(hidden_states, 0, last_eos_index) + + logits = self.classification_head(sentence_representation) + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.config.num_labels == 1: + self.config.problem_type = "regression" + elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = nn.MSELoss() + if self.config.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = nn.CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = nn.BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return 
Seq2SeqSequenceClassifierOutput( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + ) diff --git a/optimum/graphcore/models/bert/modeling_bert.py b/optimum/graphcore/models/bert/modeling_bert.py index 2db8aee3a..c91aee659 100644 --- a/optimum/graphcore/models/bert/modeling_bert.py +++ b/optimum/graphcore/models/bert/modeling_bert.py @@ -20,17 +20,17 @@ import poptorch from optimum.utils import logging from scipy.stats import truncnorm +from transformers.modeling_outputs import MaskedLMOutput, QuestionAnsweringModelOutput from transformers.models.bert.modeling_bert import ( BertForMaskedLM, BertForMultipleChoice, BertForPreTraining, + BertForPreTrainingOutput, BertForQuestionAnswering, BertForSequenceClassification, BertForTokenClassification, - BertForPreTrainingOutput, ) from transformers.utils.fx import _gen_constructor_wrapper -from transformers.modeling_outputs import MaskedLMOutput, QuestionAnsweringModelOutput from ....fx.optimization import ChangeTrueDivToMulByInverse, MergeLinears, compose from ...fx.transformations import ( @@ -61,7 +61,6 @@ _NON_REVERSIBLE_TRANSFORMATIONS = [ ClipValuesSymmetric(1e4, exclude_targets=["view"]), ClipValues(1e-4, float("inf"), include_targets=[torch.nn.LayerNorm]), - TupleOutput(), ] diff --git a/optimum/graphcore/models/convnext/modeling_convnext.py b/optimum/graphcore/models/convnext/modeling_convnext.py index 9391bd295..cd4af8362 100644 --- a/optimum/graphcore/models/convnext/modeling_convnext.py +++ b/optimum/graphcore/models/convnext/modeling_convnext.py @@ -13,7 +13,12 @@ # limitations under the License. 
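Note on the restored BartForSequenceClassification forward above: the index_select logic replaces `hidden_states[eos_mask, :]`, whose output shape depends on the data and therefore cannot be compiled to a static IPU graph. The snippet below is only a host-side sanity check of that indexing trick, assuming every padded sequence contains an eos token; it is not part of the patch.

import torch

B, L, E, eos_token_id = 2, 5, 3, 2
hidden_states = torch.randn(B, L, E)
input_ids = torch.tensor([[5, 6, 2, 0, 0], [7, 8, 9, 10, 2]])

eos_mask = torch.eq(input_ids, eos_token_id)
eos_indices = eos_mask * torch.arange(L).unsqueeze(0)
last_eos_index, _ = torch.max(eos_indices, dim=1)
static = torch.index_select(hidden_states.view(B * L, E), 0, last_eos_index + torch.arange(B) * L)

# Dynamic-shape reference: pick the hidden state at the last eos position of each row.
dynamic = torch.stack([hidden_states[b, row.nonzero().max()] for b, row in enumerate(eos_mask)])
assert torch.equal(static, dynamic)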
import torch from torch import nn -from transformers.models.convnext.modeling_convnext import ConvNextLayer, ConvNextLayerNorm, ConvNextForImageClassification + +from transformers.models.convnext.modeling_convnext import ( + ConvNextForImageClassification, + ConvNextLayer, + ConvNextLayerNorm, +) from ....fx.optimization import ChangeTrueDivToMulByInverse, MergeLinears, compose from ....utils import logging @@ -26,7 +31,6 @@ TupleOutput, ) from ...fx.utils import symbolic_trace_pipelined_model - from ...modeling_utils import PipelineMixin, get_layer_ipu, register from .optimized_convnextlayer import OptimizedConvNextLayer @@ -42,7 +46,6 @@ _NON_REVERSIBLE_TRANSFORMATIONS = [ ClipValuesSymmetric(1e4, exclude_targets=["view"]), ClipValues(1e-4, float("inf"), include_targets=[torch.nn.LayerNorm]), - TupleOutput(), ] @@ -82,18 +85,12 @@ def get_transformations(self): log_insertions = self.ipu_config.log_insertions layer_ipu = get_layer_ipu(self.ipu_config.layers_per_ipu) transformations = [ - AddPoptorchBlock( - "Embedding", 0, r"convnext.embeddings", log_insertions=log_insertions - ), + AddPoptorchBlock("Embedding", 0, r"convnext.embeddings", log_insertions=log_insertions), AddPoptorchBlocksInSeries( "Encoder", layer_ipu, r"convnext.encoder.stages.[0-9]+.layers.[0-9]+", log_insertions=log_insertions ), - AddPoptorchBlock( - "LayerNorm", layer_ipu[-1], r"convnext.layernorm", log_insertions=log_insertions - ), - AddPoptorchBlock( - "Classifier", layer_ipu[-1], r"classifier", log_insertions=log_insertions - ), + AddPoptorchBlock("LayerNorm", layer_ipu[-1], r"convnext.layernorm", log_insertions=log_insertions), + AddPoptorchBlock("Classifier", layer_ipu[-1], r"classifier", log_insertions=log_insertions), ] if self.ipu_config.recompute_checkpoint_every_layer: transformations += [ @@ -113,10 +110,10 @@ def parallelize(self): for layer in stage.layers: layer.__class__ = OptimizedConvNextLayer - # Enable autocast for ConvNextLayerNorm because computation cannot happen in fp16 - for mod in self.modules(): - if isinstance(mod, ConvNextLayerNorm): - mod.__class__ = IPUConvNextLayerNorm + # # Enable autocast for ConvNextLayerNorm because computation cannot happen in fp16 + # for mod in self.modules(): + # if isinstance(mod, ConvNextLayerNorm): + # mod.__class__ = IPUConvNextLayerNorm traced = symbolic_trace_pipelined_model(self) transformations = self.get_transformations() @@ -135,11 +132,6 @@ def deparallelize(self): if isinstance(mod, IPUConvNextLayerNorm): mod.__class__ = ConvNextLayerNorm - # Switch back to non-optimized ConvNextLayer - for stage in self.convnext.encoder.stages: - for layer in stage.layers: - layer.__class__ = ConvNextLayer - transformations = self.get_transformations() transformations += _OPTIMIZATION_TRANSFORMATIONS composition = compose(*transformations) diff --git a/optimum/graphcore/models/deberta/modeling_deberta.py b/optimum/graphcore/models/deberta/modeling_deberta.py index 23ac37e5e..ea2d32404 100644 --- a/optimum/graphcore/models/deberta/modeling_deberta.py +++ b/optimum/graphcore/models/deberta/modeling_deberta.py @@ -14,21 +14,26 @@ import math import operator -from typing import Optional, Union, Tuple +from typing import Optional, Tuple, Union import torch import torch.nn as nn import poptorch import transformers -from transformers import DebertaForMaskedLM, DebertaForQuestionAnswering, DebertaForSequenceClassification, DebertaForTokenClassification +from transformers import ( + DebertaForMaskedLM, + DebertaForQuestionAnswering, + DebertaForSequenceClassification, + 
DebertaForTokenClassification, +) +from transformers.modeling_outputs import MaskedLMOutput, QuestionAnsweringModelOutput from transformers.models.deberta.modeling_deberta import ( DebertaEncoder, DisentangledSelfAttention, StableDropout, build_relative_position, ) -from transformers.modeling_outputs import MaskedLMOutput, QuestionAnsweringModelOutput from transformers.utils.fx import _gen_constructor_wrapper from ....fx.optimization import ChangeTrueDivToMulByInverse, MergeLinears, compose @@ -45,7 +50,7 @@ VocabEmbeddingToSerializedEmbedding, ) from ...fx.utils import symbolic_trace_pipelined_model -from ...modeling_utils import PipelineMixin, get_layer_ipu, register, OnehotGather +from ...modeling_utils import OnehotGather, PipelineMixin, get_layer_ipu, register logger = logging.get_logger(__name__) @@ -59,7 +64,6 @@ _NON_REVERSIBLE_TRANSFORMATIONS = [ ClipValuesSymmetric(1e4, exclude_targets=["view"]), ClipValues(1e-4, float("inf"), include_targets=[torch.nn.LayerNorm]), - TupleOutput(), ] @@ -342,7 +346,7 @@ def parallelize(self): torch.nn.functional.one_hot = orig transformations = self.get_transformations() transformations += _OPTIMIZATION_TRANSFORMATIONS - composition = compose(*transformations, inplace=True) + composition = compose(*transformations) non_reversible_composition = compose(*_NON_REVERSIBLE_TRANSFORMATIONS) traced = composition(traced) traced = non_reversible_composition(traced) diff --git a/optimum/graphcore/models/distilbert/modeling_distilbert.py b/optimum/graphcore/models/distilbert/modeling_distilbert.py index 25213403e..e64a27a5b 100644 --- a/optimum/graphcore/models/distilbert/modeling_distilbert.py +++ b/optimum/graphcore/models/distilbert/modeling_distilbert.py @@ -20,9 +20,17 @@ import torch.nn.functional as F import poptorch - +from optimum.utils import logging +from transformers import ( + DistilBertForMaskedLM, + DistilBertForMultipleChoice, + DistilBertForQuestionAnswering, + DistilBertForSequenceClassification, + DistilBertForTokenClassification, +) from transformers.modeling_outputs import MaskedLMOutput, QuestionAnsweringModelOutput from transformers.models.distilbert.modeling_distilbert import MultiHeadSelfAttention + from ....fx.optimization import ChangeTrueDivToMulByInverse, MergeLinears, compose from ...fx.transformations import ( AddPoptorchBlock, @@ -38,14 +46,6 @@ ) from ...fx.utils import symbolic_trace_pipelined_model from ...modeling_utils import OnehotGather, PipelineMixin, get_layer_ipu, register -from optimum.utils import logging -from transformers import ( - DistilBertForMaskedLM, - DistilBertForMultipleChoice, - DistilBertForQuestionAnswering, - DistilBertForSequenceClassification, - DistilBertForTokenClassification, -) logger = logging.get_logger(__name__) @@ -59,7 +59,6 @@ _NON_REVERSIBLE_TRANSFORMATIONS = [ ClipValuesSymmetric(1e4, exclude_targets=["view"]), ClipValues(1e-4, float("inf"), include_targets=[torch.nn.LayerNorm]), - TupleOutput(), ] @@ -143,7 +142,7 @@ def get_transformations(self): AddPoptorchBlock("Embedding", 0, "distilbert.embeddings", log_insertions=log_insertions), OutlineAttribute("distilbert.embeddings.LayerNorm", "Embedding"), AddPoptorchBlocksInSeries( - "Encoder", layer_ipu, r"distilbert.encoder.layer.[0-9]+", log_insertions=log_insertions + "Encoder", layer_ipu, r"distilbert.transformer.layer.[0-9]+", log_insertions=log_insertions ), # Only one of the following AddPoptorchBlock, will actually add a block. 
AddPoptorchBlock("Classifier Output", last_ipu, "classifier", log_insertions=log_insertions), @@ -152,7 +151,8 @@ def get_transformations(self): if self.ipu_config.recompute_checkpoint_every_layer: transformations.append( RecomputationCheckpoint( - "distilbert.encoder.layer.[0-9]+", to_exclude=f"distilbert.encoder.layer.{self.config.num_hidden_layers - 1}" + "distilbert.transformer.layer.[0-9]+", + to_exclude=f"distilbert.transformer.layer.{self.config.num_hidden_layers - 1}", ) ) if self.ipu_config.embedding_serialization_factor > 1: @@ -218,7 +218,8 @@ def get_transformations(self): if self.ipu_config.recompute_checkpoint_every_layer: transformations.append( RecomputationCheckpoint( - "distilbert.encoder.layer.[0-9]+", to_exclude=f"distilbert.encoder.layer.{self.config.num_hidden_layers - 1}" + "distilbert.encoder.layer.[0-9]+", + to_exclude=f"distilbert.encoder.layer.{self.config.num_hidden_layers - 1}", ) ) if self.ipu_config.embedding_serialization_factor > 1: @@ -332,7 +333,6 @@ class PipelinedDistilBertForSequenceClassification(DistilBertForSequenceClassifi @register(DistilBertForQuestionAnswering) class PipelinedDistilBertForQuestionAnswering(DistilBertForQuestionAnswering, DistilBertPipelineMixin): - def forward( self, input_ids: Optional[torch.Tensor] = None, diff --git a/optimum/graphcore/models/gpt2/modeling_gpt2.py b/optimum/graphcore/models/gpt2/modeling_gpt2.py index ecd649ca7..4e4d8d881 100644 --- a/optimum/graphcore/models/gpt2/modeling_gpt2.py +++ b/optimum/graphcore/models/gpt2/modeling_gpt2.py @@ -53,7 +53,6 @@ _NON_REVERSIBLE_TRANSFORMATIONS = [ ClipValuesSymmetric(1e4, exclude_targets=["view"]), ClipValues(1e-4, float("inf"), include_targets=[torch.nn.LayerNorm]), - TupleOutput(), ] diff --git a/optimum/graphcore/models/hubert/modeling_hubert.py b/optimum/graphcore/models/hubert/modeling_hubert.py index c337911ec..f5a8749d1 100644 --- a/optimum/graphcore/models/hubert/modeling_hubert.py +++ b/optimum/graphcore/models/hubert/modeling_hubert.py @@ -41,7 +41,6 @@ _NON_REVERSIBLE_TRANSFORMATIONS = [ ClipValuesSymmetric(1e4, exclude_targets=["view"]), ClipValues(1e-4, float("inf"), include_targets=[torch.nn.LayerNorm]), - TupleOutput(), ] diff --git a/optimum/graphcore/models/lxmert/modeling_lxmert.py b/optimum/graphcore/models/lxmert/modeling_lxmert.py index ce4a5eab9..43670d66f 100644 --- a/optimum/graphcore/models/lxmert/modeling_lxmert.py +++ b/optimum/graphcore/models/lxmert/modeling_lxmert.py @@ -11,12 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
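Note on the RecomputationCheckpoint fixes above: the regexes have to name the modules as they actually appear in DistilBERT, i.e. distilbert.transformer.layer.N rather than the BERT-style distilbert.encoder.layer.N. The snippet below only illustrates the name filtering (dots are left unescaped, as in the patch); the real transformation additionally walks the FX graph and annotates the matched modules' output nodes.

import re

name_regex = re.compile(r"distilbert.transformer.layer.[0-9]+")
to_exclude = re.compile(r"distilbert.transformer.layer.5")  # last layer of a 6-layer model

for name in [
    "distilbert.embeddings",
    "distilbert.transformer.layer.0.attention.q_lin",
    "distilbert.transformer.layer.5.ffn.lin2",
]:
    checkpointed = bool(name_regex.match(name)) and not to_exclude.match(name)
    print(name, "->", bool(checkpointed))
# distilbert.embeddings -> False
# distilbert.transformer.layer.0.attention.q_lin -> True
# distilbert.transformer.layer.5.ffn.lin2 -> False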
-from typing import Optional, Union, Tuple +from typing import Optional, Tuple, Union import torch import torch.nn.functional as F from transformers.models.lxmert.modeling_lxmert import LxmertForQuestionAnswering, LxmertForQuestionAnsweringOutput + from ....fx.optimization import ChangeTrueDivToMulByInverse, MergeLinears, compose from ....utils import logging from ...fx.transformations import ( @@ -42,7 +43,6 @@ _NON_REVERSIBLE_TRANSFORMATIONS = [ ClipValuesSymmetric(1e4, exclude_targets=["view"]), ClipValues(1e-4, float("inf"), include_targets=[torch.nn.LayerNorm]), - TupleOutput(), ] diff --git a/optimum/graphcore/models/roberta/modeling_roberta.py b/optimum/graphcore/models/roberta/modeling_roberta.py index 53235c9ca..3183b6ac4 100644 --- a/optimum/graphcore/models/roberta/modeling_roberta.py +++ b/optimum/graphcore/models/roberta/modeling_roberta.py @@ -57,7 +57,6 @@ _NON_REVERSIBLE_TRANSFORMATIONS = [ ClipValuesSymmetric(1e4, exclude_targets=["view"]), ClipValues(1e-4, float("inf"), include_targets=[torch.nn.LayerNorm]), - TupleOutput(), ] diff --git a/optimum/graphcore/models/t5/modeling_t5.py b/optimum/graphcore/models/t5/modeling_t5.py index 59b24ef61..18ce59c92 100644 --- a/optimum/graphcore/models/t5/modeling_t5.py +++ b/optimum/graphcore/models/t5/modeling_t5.py @@ -52,7 +52,6 @@ _NON_REVERSIBLE_TRANSFORMATIONS = [ ClipValuesSymmetric(1e4, exclude_targets=["view"]), ClipValues(1e-4, float("inf"), include_targets=[torch.nn.LayerNorm]), - TupleOutput(), ] @@ -147,9 +146,9 @@ def parallelize(self): ``` """ PipelineMixin.parallelize(self) - for mod in self.modules(): - if isinstance(mod, T5LayerNorm): - mod.forward = poptorch.autocast(enabled=True)(mod.forward) + # for mod in self.modules(): + # if isinstance(mod, T5LayerNorm): + # mod.forward = poptorch.autocast(enabled=True)(mod.forward) traced = symbolic_trace_pipelined_model(self) transformations = self.get_transformations() transformations += _OPTIMIZATION_TRANSFORMATIONS @@ -174,37 +173,113 @@ def deparallelize(self): self = composition(self, reverse=True) return self - # def train(self, mode: bool = True) -> "PipelinedT5ForConditionalGeneration": - # mod = super(T5ForConditionalGeneration, self).train(mode=mode) - # # TODO: enable that once generation is supported. - # # mod.forward = mod._forward_for_train if mode else mod._forward_for_generate - # mod.forward = mod._forward_for_train - # return mod - - # def _forward_for_train(self, input_ids, attention_mask, decoder_input_ids, labels=None): - # outputs = super().forward( - # input_ids=input_ids, - # attention_mask=attention_mask, - # decoder_input_ids=decoder_input_ids, - # labels=labels, - # use_cache=False, - # return_dict=False, - # ) - # # Only returning the loss to make the communication between the host and the device faster. - # return outputs[0:1] - - # def _forward_for_generate(self, encoder_outputs, decoder_input_ids, attention_mask, labels=None): - # outputs = super().forward( - # encoder_outputs=encoder_outputs, - # attention_mask=attention_mask, - # decoder_input_ids=decoder_input_ids, - # return_dict=False, - # use_cache=False, - # labels=labels, - # ) - # # Only returning the loss (if labels is provided) and the logits. 
- # if labels is None: - # return outputs[:1] - # return outputs[:2] - - # forward = _forward_for_train + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.BoolTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + decoder_head_mask: Optional[torch.FloatTensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]: + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask + if head_mask is not None and decoder_head_mask is None: + if self.config.num_layers == self.config.num_decoder_layers: + warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning) + decoder_head_mask = head_mask + + # Encode if needed (training, first prediction pass) + if encoder_outputs is None: + # Convert encoder inputs in embeddings if needed + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + hidden_states = encoder_outputs[0] + + if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None: + # get decoder inputs from shifting lm labels to the right + decoder_input_ids = self._shift_right(labels) + + # Decode + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + inputs_embeds=decoder_inputs_embeds, + past_key_values=past_key_values, + encoder_hidden_states=hidden_states, + encoder_attention_mask=attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = decoder_outputs[0] + + # Set device for model parallelism + if self.model_parallel: + torch.cuda.set_device(self.encoder.first_device) + self.lm_head = self.lm_head.to(self.encoder.first_device) + sequence_output = sequence_output.to(self.lm_head.weight.device) + + if self.config.tie_word_embeddings: + # Rescale output before projecting on vocab + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 + sequence_output = sequence_output * (self.model_dim**-0.5) + + lm_logits = self.lm_head(sequence_output) + + loss = None + if labels is not 
None: + loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-100) + loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1)) + # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666 + + # Only returning the loss to make the communication between the host and the device faster. + if not return_dict: + output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs + return (loss,) if labels is not None else output + + if loss is not None: + return Seq2SeqLMOutput( + loss=loss, + ) + return Seq2SeqLMOutput( + loss=loss, + logits=lm_logits, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) diff --git a/optimum/graphcore/models/vit/modeling_vit.py b/optimum/graphcore/models/vit/modeling_vit.py index ef36046db..45e8d0a40 100644 --- a/optimum/graphcore/models/vit/modeling_vit.py +++ b/optimum/graphcore/models/vit/modeling_vit.py @@ -40,7 +40,6 @@ _NON_REVERSIBLE_TRANSFORMATIONS = [ ClipValuesSymmetric(1e4, exclude_targets=["view"]), ClipValues(1e-4, float("inf"), include_targets=[torch.nn.LayerNorm]), - TupleOutput(), ] diff --git a/optimum/graphcore/models/wav2vec2/modeling_wav2vec2.py b/optimum/graphcore/models/wav2vec2/modeling_wav2vec2.py index 15a9b6546..344225bc4 100755 --- a/optimum/graphcore/models/wav2vec2/modeling_wav2vec2.py +++ b/optimum/graphcore/models/wav2vec2/modeling_wav2vec2.py @@ -26,8 +26,8 @@ Wav2Vec2Encoder, Wav2Vec2EncoderStableLayerNorm, Wav2Vec2ForCTC, - Wav2Vec2GumbelVectorQuantizer, Wav2Vec2ForPreTrainingOutput, + Wav2Vec2GumbelVectorQuantizer, ) from ....fx.optimization import ChangeTrueDivToMulByInverse, MergeLinears, compose @@ -42,7 +42,6 @@ ) from ...fx.utils import symbolic_trace_pipelined_model from ...modeling_utils import PipelineMixin, get_layer_ipu, register - from .ipu_gumbel_vector_quantizer import IPUWav2Vec2GumbelVectorQuantizer from .ipu_layer_drop import IPUWav2Vec2Adapter, IPUWav2Vec2Encoder, IPUWav2Vec2EncoderStableLayerNorm @@ -58,7 +57,6 @@ _NON_REVERSIBLE_TRANSFORMATIONS = [ ClipValuesSymmetric(1e4, exclude_targets=["view"]), ClipValues(1e-4, float("inf"), include_targets=[torch.nn.LayerNorm]), - TupleOutput(), ] @@ -100,14 +98,25 @@ class Wav2Vec2PipelineMixin(PipelineMixin): def get_transformations(self): log_insertions = self.ipu_config.log_insertions layer_ipu = get_layer_ipu(self.ipu_config.layers_per_ipu) - feature_extractor_conv_layers_ipu = layer_ipu[:self.config.num_feat_extract_layers] + feature_extractor_conv_layers_ipu = layer_ipu[: self.config.num_feat_extract_layers] transformations = [ AddPoptorchBlocksInSeries( - "Conv", feature_extractor_conv_layers_ipu, r"wav2vec2.feature_extractor.conv_layers.[0-9]+", log_insertions=log_insertions + "Conv", + feature_extractor_conv_layers_ipu, + r"wav2vec2.feature_extractor.conv_layers.[0-9]+", + log_insertions=log_insertions, + ), + AddPoptorchBlock( + "Positional Embedding", + layer_ipu[self.config.num_feat_extract_layers], + "wav2vec2.encoder.pos_conv_embed", + log_insertions=log_insertions, ), - AddPoptorchBlock("Positional Embedding", layer_ipu[self.config.num_feat_extract_layers], "wav2vec2.encoder.pos_conv_embed", log_insertions=log_insertions), AddPoptorchBlocksInSeries( - 
"Encoder", layer_ipu[self.config.num_feat_extract_layers + 1:], r"wav2vec2.encoder.layers.[0-9]+", log_insertions=log_insertions + "Encoder", + layer_ipu[self.config.num_feat_extract_layers + 1 :], + r"wav2vec2.encoder.layers.[0-9]+", + log_insertions=log_insertions, ), ] if self.ipu_config.recompute_checkpoint_every_layer: @@ -199,7 +208,9 @@ def get_transformations(self): transformations += [ AddPoptorchBlock("Project Hidden", layer_ipu[start_idx], "project_hid", log_insertions=log_insertions), AddPoptorchBlock("Quantizer", layer_ipu[start_idx + 1], "quantizer", log_insertions=log_insertions), - AddPoptorchBlock("Project Quantizer", layer_ipu[start_idx + 2], "project_q", log_insertions=log_insertions), + AddPoptorchBlock( + "Project Quantizer", layer_ipu[start_idx + 2], "project_q", log_insertions=log_insertions + ), ] return transformations diff --git a/optimum/graphcore/trainer.py b/optimum/graphcore/trainer.py index 0ecd5c090..3deffb594 100644 --- a/optimum/graphcore/trainer.py +++ b/optimum/graphcore/trainer.py @@ -191,7 +191,7 @@ class IPUTrainer: be able to choose different architectures according to hyper parameters (such as layer count, sizes of inner layers, dropout probabilities etc). **Note: this feature is not supported for now.** - compute_metrics (`Callable[[~transformers.trainer_utils.EvalPrediction], Dict]`, *optional*): + compute_metrics (`Callable[[transformers.trainer_utils.EvalPrediction], Dict]`, *optional*): The function that will be used to compute metrics at evaluation. Must take a [`~transformers.trainer_utils.EvalPrediction`] and return a dictionary string to metric values. callbacks (List of [`transformers.trainer_callback.TrainerCallback`], *optional*): @@ -485,7 +485,6 @@ def compile_model( model = model.parallelize() if not self.args.fp32: model.half() - import pdb; pdb.set_trace() if training: self.model = model else: @@ -1090,14 +1089,11 @@ def _inner_training_loop( if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug: debug_overflow = DebugUnderflowOverflow(self.model) # noqa - if trial is not None: raise ValueError("Hyperparameter tuning is not supported by the IPUTrainer.") trial = None self.state.is_hyper_param_search = trial is not None - - # self.training_model = self.wrap_model(self.model) self.traning_model = self.compile_model(next(iter(train_dataloader)), training=True) diff --git a/tests/test_pipelined_models.py b/tests/test_pipelined_models.py index 31c32825f..49c4a8bd2 100644 --- a/tests/test_pipelined_models.py +++ b/tests/test_pipelined_models.py @@ -12,13 +12,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import copy +import inspect from unittest import TestCase import torch from datasets import load_dataset from PIL import Image -from torch.nn.utils.weight_norm import WeightNorm import requests import transformers @@ -175,10 +174,14 @@ def test_pretrained_and_pipelined_models_match( pretrained_model_outputs = pretrained_model(**inputs, return_dict=True) # The forward method can be different in train and eval mode for some models (seq2seq for instance), so we make # sure to use the proper one. 
- pipelined_forward_function = getattr(pipelined_model, "_forward_for_train", pipelined_model.forward) + # pipelined_forward_function = getattr(pipelined_model, "_forward_for_train", pipelined_model.forward) - pipelined_model.parallelize() - pipelined_model_outputs = pipelined_forward_function(**inputs, return_dict=True) + input_names = [p for p in inspect.signature(pipelined_model.forward).parameters if p in inputs] + inputs_values = [inputs[k] for k in input_names] + pipelined_model.input_names_for_symbolic_trace = input_names + + pipelined_model = pipelined_model.parallelize() + pipelined_model_outputs = pipelined_model(*inputs_values) for idx, k in enumerate(pretrained_model_outputs.keys()): pretrained_output, pipelined_output = pretrained_model_outputs[k], pipelined_model_outputs[k] # Handle tuple outputs. Outputs such as past_key_values are returned as tuples. @@ -202,8 +205,8 @@ def test_pretrained_and_pipelined_models_match( f"Pretrained and pipelined model {idx}th outputs do not match, max difference = {(pretrained_output - pipelined_output).abs().max()}", ) - pipelined_model.deparallelize() - pipelined_model_outputs = pipelined_forward_function(**inputs) + pipelined_model = pipelined_model.deparallelize() + pipelined_model_outputs = pipelined_model(*inputs_values) for idx, k in enumerate(pretrained_model_outputs.keys()): pretrained_output, pipelined_output = pretrained_model_outputs[k], pipelined_model_outputs[k] # Handle tuple outputs. Outputs such as past_key_values are returned as tuples. @@ -227,26 +230,26 @@ def test_pretrained_and_pipelined_models_match( f"Pretrained and pipelined model {idx}th outputs do not match, max difference = {(pretrained_output - pipelined_output).abs().max()}", ) - @parameterized.expand(MODELS_TO_TEST) - def test_parallelize_deparallelize( - self, test_name, model_name_or_path, ipu_config_name_or_path, pretrained_class, pipelined_class, config_class - ): - ipu_config = IPUConfig.from_pretrained(ipu_config_name_or_path) - model = pipelined_class.from_pretrained_transformers(model_name_or_path, ipu_config) - - # Remove the weight-norm hook, if present, because it doesn't work with deepcopy - # https://github.com/pytorch/pytorch/issues/28594 - for module in model.modules(): - for _, hook in module._forward_pre_hooks.items(): - if isinstance(hook, WeightNorm): - delattr(module, hook.name) - - modules_before = copy.deepcopy(model).modules() - model.parallelize() - model.deparallelize() - modules_after = copy.deepcopy(model).modules() - # Confirm that parallelize then deparallelize won't change the model's modules - for mod_before, mod_after in zip(modules_before, modules_after): - self.assertEqual(type(mod_before), type(mod_after)) - - model.parallelize() + # @parameterized.expand(MODELS_TO_TEST) + # def test_parallelize_deparallelize( + # self, test_name, model_name_or_path, ipu_config_name_or_path, pretrained_class, pipelined_class, config_class + # ): + # ipu_config = IPUConfig.from_pretrained(ipu_config_name_or_path) + # model = pipelined_class.from_pretrained_transformers(model_name_or_path, ipu_config) + + # # Remove the weight-norm hook, if present, because it doesn't work with deepcopy + # # https://github.com/pytorch/pytorch/issues/28594 + # for module in model.modules(): + # for _, hook in module._forward_pre_hooks.items(): + # if isinstance(hook, WeightNorm): + # delattr(module, hook.name) + + # modules_before = copy.deepcopy(model).modules() + # model.parallelize() + # model.deparallelize() + # modules_after = copy.deepcopy(model).modules() 
+ # # Confirm that parallelize then deparallelize won't change the model's modules + # for mod_before, mod_after in zip(modules_before, modules_after): + # self.assertEqual(type(mod_before), type(mod_after)) + + # model.parallelize() diff --git a/tests/test_trainer.py b/tests/test_trainer.py index 87439834f..8ad04d0b9 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -14,7 +14,6 @@ # limitations under the License. import dataclasses -import gc import math import os import random @@ -29,9 +28,10 @@ from huggingface_hub import HfFolder, Repository, delete_repo, set_access_token from optimum.graphcore import IPUConfig, IPUTrainingArguments +from optimum.graphcore.fx.utils import cast_traced_model_to_proper_class from optimum.utils import logging from requests.exceptions import HTTPError -from transformers import AutoTokenizer, IntervalStrategy, PretrainedConfig, is_torch_available +from transformers import IntervalStrategy, PretrainedConfig, is_torch_available from transformers.file_utils import WEIGHTS_NAME from transformers.testing_utils import ( ENDPOINT_STAGING, @@ -42,25 +42,17 @@ get_gpu_count, get_tests_dir, is_staging_test, - require_optuna, - require_ray, require_sentencepiece, - require_sigopt, require_tokenizers, require_torch, - require_torch_gpu, - require_torch_multi_gpu, - require_torch_non_multi_gpu, - require_torch_up_to_2_gpus, - slow, ) from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR -from transformers.utils.hp_naming import TrialShortNamer if is_torch_available(): import torch from torch import nn + from torch.fx import symbolic_trace from torch.utils.data import IterableDataset import poptorch @@ -170,6 +162,15 @@ def __iter__(self): for i in range(len(self.dataset)): yield self.dataset[i] + def symbolic_trace_and_cast(model: torch.nn.Module): + if isinstance(model, torch.fx.GraphModule): + return model + traced = symbolic_trace(model) + for name, value in model.__dict__.items(): + setattr(traced, name, value) + cast_traced_model_to_proper_class(model, traced) + return traced + class RegressionModel(nn.Module): def __init__(self, a=0, b=0, double_output=False): super().__init__() @@ -185,6 +186,9 @@ def forward(self, input_x, labels=None, labels_2=None): loss = nn.functional.mse_loss(y, labels) return (loss, y, y) if self.double_output else (loss, y) + def parallelize(self): + return symbolic_trace_and_cast(self) + class RegressionDictModel(nn.Module): def __init__(self, a=0, b=0): super().__init__() @@ -199,6 +203,9 @@ def forward(self, input_x, labels=None): result["loss"] = nn.functional.mse_loss(y, labels) return result + def parallelize(self): + return symbolic_trace_and_cast(self) + class RegressionPreTrainedModel(PreTrainedModel): config_class = RegressionModelConfig base_model_prefix = "regression" @@ -216,6 +223,9 @@ def forward(self, input_x, labels=None, labels_2=None): loss = nn.functional.mse_loss(y, labels) return (loss, y, y) if self.double_output else (loss, y) + def parallelize(self): + return symbolic_trace_and_cast(self) + class RegressionRandomPreTrainedModel(PreTrainedModel): config_class = RegressionModelConfig base_model_prefix = "regression" @@ -239,6 +249,9 @@ def forward(self, input_x, labels=None): loss = nn.functional.mse_loss(y, labels) return (loss, y) + def parallelize(self): + return symbolic_trace_and_cast(self) + class TstLayer(nn.Module): def __init__(self, hidden_size): super().__init__() From 0b2342821d35d0d3326339e5bb48c2a2d2238ef4 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Wed, 19 Oct 2022 
14:46:07 +0200 Subject: [PATCH 21/33] All tests but test_examples are passing --- optimum/graphcore/fx/utils.py | 40 ++++++++++++------- .../models/convnext/modeling_convnext.py | 8 ++-- .../graphcore/models/gpt2/modeling_gpt2.py | 2 - .../models/lxmert/modeling_lxmert.py | 9 +++-- tests/test_pipelined_models.py | 6 +++ 5 files changed, 42 insertions(+), 23 deletions(-) diff --git a/optimum/graphcore/fx/utils.py b/optimum/graphcore/fx/utils.py index 8d5d9f0e4..01b5f772f 100644 --- a/optimum/graphcore/fx/utils.py +++ b/optimum/graphcore/fx/utils.py @@ -50,6 +50,7 @@ def __init__(self, autowrap_modules=(math,), autowrap_functions=()): self.ops_to_wrap = [] self.current_module_qualified_name = ["root"] self.current_module_type = ["root"] + self.root_is_in_half_precision = False def register_op_to_wrap(self, name, wrapper, orig_op): self.ops_to_wrap.append((name, wrapper, orig_op)) @@ -95,20 +96,20 @@ def call_module(self, m, forward, args, kwargs): return proxy def create_proxy(self, kind, target, args, kwargs, name=None, type_expr=None, proxy_factory_fn=None): - # TODO: how to handle the case where the model is ran in full-precision? - float32_dtype_in_args = any(a is torch.float32 for a in args) - float32_dtype_in_kwargs = kwargs.get("dtype", None) is torch.float32 - node_types_to_inspect = [ - ("call_method", "to"), - ("call_function", torch.full), - ] - torch_methods_to_patched_version = {orig: wrapped for (orig, wrapped) in self.patched_torch_methods.values()} - for (k, t) in node_types_to_inspect: - if kind == k and target == torch_methods_to_patched_version.get(t, t): - if float32_dtype_in_args: - args = tuple(a if a is not torch.float32 else torch.float16 for a in args) - if float32_dtype_in_kwargs: - kwargs["dtype"] = torch.float16 + if self.root_is_in_half_precision: + float32_dtype_in_args = any(a is torch.float32 for a in args) + float32_dtype_in_kwargs = kwargs.get("dtype", None) is torch.float32 + node_types_to_inspect = [ + ("call_method", "to"), + ("call_function", torch.full), + ] + torch_methods_to_patched_version = {orig: wrapped for (orig, wrapped) in self.patched_torch_methods.values()} + for (k, t) in node_types_to_inspect: + if kind == k and target == torch_methods_to_patched_version.get(t, t): + if float32_dtype_in_args: + args = tuple(a if a is not torch.float32 else torch.float16 for a in args) + if float32_dtype_in_kwargs: + kwargs["dtype"] = torch.float16 return super().create_proxy( kind, target, args, kwargs, name=name, type_expr=type_expr, proxy_factory_fn=proxy_factory_fn ) @@ -128,6 +129,17 @@ def _generate_dummy_input( input_dict = super()._generate_dummy_input(model, input_name, shape) return input_dict + def trace(self, *args, **kwargs) -> torch.fx.Graph: + root = args[0] + if not isinstance(root, torch.nn.Module): + # Cannot infer easily. 
+ self.root_is_in_half_precision = False + else: + self.root_is_in_half_precision = any(p.dtype is torch.float16 for p in root.parameters()) + graph = super().trace(*args, **kwargs) + self.root_is_in_half_precision = False + return graph + def symbolic_trace_with_pipelined_tracer( model: PipelineMixin, diff --git a/optimum/graphcore/models/convnext/modeling_convnext.py b/optimum/graphcore/models/convnext/modeling_convnext.py index cd4af8362..fc61750a8 100644 --- a/optimum/graphcore/models/convnext/modeling_convnext.py +++ b/optimum/graphcore/models/convnext/modeling_convnext.py @@ -110,10 +110,10 @@ def parallelize(self): for layer in stage.layers: layer.__class__ = OptimizedConvNextLayer - # # Enable autocast for ConvNextLayerNorm because computation cannot happen in fp16 - # for mod in self.modules(): - # if isinstance(mod, ConvNextLayerNorm): - # mod.__class__ = IPUConvNextLayerNorm + # Enable autocast for ConvNextLayerNorm because computation cannot happen in fp16 + for mod in self.modules(): + if isinstance(mod, ConvNextLayerNorm): + mod.__class__ = IPUConvNextLayerNorm traced = symbolic_trace_pipelined_model(self) transformations = self.get_transformations() diff --git a/optimum/graphcore/models/gpt2/modeling_gpt2.py b/optimum/graphcore/models/gpt2/modeling_gpt2.py index 4e4d8d881..52dc67921 100644 --- a/optimum/graphcore/models/gpt2/modeling_gpt2.py +++ b/optimum/graphcore/models/gpt2/modeling_gpt2.py @@ -98,7 +98,6 @@ def get_transformations(self): ) ) if self.ipu_config.embedding_serialization_factor > 1: - self.resize_vocab(False) transformations.append(VocabEmbeddingToSerializedEmbedding()) return transformations @@ -159,7 +158,6 @@ def get_transformations(self): ) ) if self.ipu_config.embedding_serialization_factor > 1: - self.resize_vocab(False) transformations += [ LinearToSerializedLinear("lm_head"), TieWeights("transformer.wte", "lm_head"), diff --git a/optimum/graphcore/models/lxmert/modeling_lxmert.py b/optimum/graphcore/models/lxmert/modeling_lxmert.py index 43670d66f..c0ac46a01 100644 --- a/optimum/graphcore/models/lxmert/modeling_lxmert.py +++ b/optimum/graphcore/models/lxmert/modeling_lxmert.py @@ -36,7 +36,8 @@ _OPTIMIZATION_TRANSFORMATIONS = [ ChangeTrueDivToMulByInverse(), - MergeLinears(), + # TODO: Not working for now. + # MergeLinears(), # FuseBiasInLinear(), ] @@ -51,9 +52,11 @@ class PipelinedLxmertForQuestionAnswering(LxmertForQuestionAnswering, PipelineMi def get_transformations(self): log_insertions = self.ipu_config.log_insertions layer_ipu = get_layer_ipu(self.ipu_config.layers_per_ipu) + # TODO: remove this line after testing. 
+ layer_ipu = get_layer_ipu([0, 7, 7, 5]) language_layers_ipus = layer_ipu[: self.config.l_layers] - visual_layers_ipus = layer_ipu[self.config.l_layers : self.config.l_layers + self.r_layers] - cross_modality_layers_ipus = layer_ipu[self.config.l_layers + self.r_layers :] + visual_layers_ipus = layer_ipu[self.config.l_layers : self.config.l_layers + self.config.r_layers] + cross_modality_layers_ipus = layer_ipu[self.config.l_layers + self.config.r_layers :] transformations = [ AddPoptorchBlock("Embedding", 0, "lxmert.embeddings", log_insertions=log_insertions), diff --git a/tests/test_pipelined_models.py b/tests/test_pipelined_models.py index 49c4a8bd2..2a887c011 100644 --- a/tests/test_pipelined_models.py +++ b/tests/test_pipelined_models.py @@ -167,6 +167,12 @@ def test_pretrained_and_pipelined_models_match( ): config = config_class.from_pretrained(model_name_or_path) ipu_config = IPUConfig.from_pretrained(ipu_config_name_or_path) + if "gpt2" in model_name_or_path: + if pretrained_class in MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.values(): + config.pad_token_id = 2 + if pretrained_class in MODEL_FOR_CAUSAL_LM_MAPPING.values(): + # Disabling it because otherwise we are resizing the vocab, which makes outputs comparison impossible. + ipu_config.embedding_serialization_factor = 1 pretrained_model = pretrained_class(config).eval() pipelined_model = pipelined_class.from_transformers(pretrained_model, ipu_config).eval() From a7abe141eb8e55aca24bca136429e3722739b0a3 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Thu, 20 Oct 2022 18:00:48 +0200 Subject: [PATCH 22/33] Add TransformationManager --- optimum/graphcore/fx/transformations.py | 12 +- optimum/graphcore/fx/utils.py | 3 +- optimum/graphcore/ipu_configuration.py | 7 + .../graphcore/models/bart/modeling_bart.py | 157 +++++--------- .../graphcore/models/bert/modeling_bert.py | 195 ++++++------------ .../models/deberta/modeling_deberta.py | 6 +- .../graphcore/models/gpt2/modeling_gpt2.py | 65 +----- .../graphcore/models/hubert/ipu_layer_drop.py | 4 +- .../models/hubert/modeling_hubert.py | 45 ++-- optimum/graphcore/models/t5/modeling_t5.py | 33 +-- optimum/graphcore/models/vit/modeling_vit.py | 29 +-- 11 files changed, 182 insertions(+), 374 deletions(-) diff --git a/optimum/graphcore/fx/transformations.py b/optimum/graphcore/fx/transformations.py index 3c034ef94..9a476e643 100644 --- a/optimum/graphcore/fx/transformations.py +++ b/optimum/graphcore/fx/transformations.py @@ -160,13 +160,13 @@ def __init__( self, min_value: float, max_value: float, - include_targets: Optional[List[Union[str, Callable]]] = None, - exclude_targets: Optional[List[Union[str, Callable]]] = None, + include_targets: Optional[Tuple[Union[str, Callable]]] = None, + exclude_targets: Optional[Tuple[Union[str, Callable]]] = None, ): self.min_value = min_value self.max_value = max_value - self.include_targets = include_targets if include_targets is not None else [] - self.exclude_targets = exclude_targets if exclude_targets is not None else [] + self.include_targets = include_targets if include_targets is not None else () + self.exclude_targets = exclude_targets if exclude_targets is not None else () def _clip_node_args(self, args): if isinstance(args, (tuple, list, set)): @@ -197,8 +197,8 @@ class ClipValuesSymmetric(ClipValues): def __init__( self, clip_value: float, - include_targets: Optional[List[Union[str, Callable]]] = None, - exclude_targets: Optional[List[Union[str, Callable]]] = None, + include_targets: Optional[Tuple[Union[str, Callable]]] = None, + 
exclude_targets: Optional[Tuple[Union[str, Callable]]] = None, ): if clip_value < 0: raise ValueError(f"The provided clip value must be equal or greater than 0, but here {clip_value}.") diff --git a/optimum/graphcore/fx/utils.py b/optimum/graphcore/fx/utils.py index 01b5f772f..d33a132ed 100644 --- a/optimum/graphcore/fx/utils.py +++ b/optimum/graphcore/fx/utils.py @@ -37,6 +37,7 @@ class PipelinedTracer(HFTracer): # TODO: keep this until transformers >= 4.23.2 _TORCH_METHODS_TO_PATCH = list(HFTracer._TORCH_METHODS_TO_PATCH) _TORCH_METHODS_TO_PATCH.append("clamp") + _TORCH_METHODS_TO_PATCH.append("rand") """ Tracer that enables tracing and transforming models to run them on IPUs. Compared to the HFTracer, this one adds the following features: @@ -125,7 +126,7 @@ def _generate_dummy_input( input_dict["labels"] = torch.zeros(*shape, dtype=torch.float, device=model.device) if model_class_name in get_values(MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES): input_dict["labels"] = torch.zeros(shape[0], dtype=torch.long, device=model.device) - else: + if "labels" not in input_dict: input_dict = super()._generate_dummy_input(model, input_name, shape) return input_dict diff --git a/optimum/graphcore/ipu_configuration.py b/optimum/graphcore/ipu_configuration.py index ac7ca8950..9fd5440d6 100644 --- a/optimum/graphcore/ipu_configuration.py +++ b/optimum/graphcore/ipu_configuration.py @@ -49,6 +49,11 @@ class IPUConfig(BaseConfig): **Note: This is an experimental feature and may not behave as expected.** executable_cache_dir (`str`, *optional*, defaults to `""`): Enables caching the compile executables to a directory. + optimization_level (`int`, *optional*, defaults to 1): + The optimization level to apply to the model before compilation. Three values are allowed: + - 0: No optimization is performed on the graph. + - 1: Optimizations that preserve the computation (same result as no optimization) are performed on the graph. + - 2: All the available optimizations are applied to the graph, potentially including approximations. 
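[Editor's note] The sketch below is not part of the patch; it only illustrates how the new `optimization_level` value is meant to be consumed by the `TransformationManager` added in this commit. The import paths and the manager methods (`without`, `register`, `get_reversible_transformations`, `compose_non_reversible_transformations`) are taken from the surrounding hunks; the helper function and the re-exports from `optimum.graphcore.fx` are assumptions for illustration only.

```
# Editor's sketch, not part of the patch.
from optimum.fx.optimization import compose
from optimum.graphcore.fx import (  # re-exports assumed from the hunks in this commit
    DEFAULT_TRANSFORMATION_MANAGER,
    ClipValuesSymmetric,
    symbolic_trace_pipelined_model,
)


def parallelize_with_manager(pipelined_model, manager=DEFAULT_TRANSFORMATION_MANAGER):
    level = pipelined_model.ipu_config.optimization_level  # 0, 1 or 2
    traced = symbolic_trace_pipelined_model(pipelined_model)
    # Model-specific transformations plus the level-gated reversible ones.
    transformations = pipelined_model.get_transformations()
    transformations += manager.get_reversible_transformations(level)
    traced = compose(*transformations)(traced)
    # Non-reversible transformations (e.g. value clipping) are composed separately.
    traced = manager.compose_non_reversible_transformations(level)(traced)
    return traced


# A model file can swap out one of the default transformations, as the BART hunk does:
bart_manager = DEFAULT_TRANSFORMATION_MANAGER.without(ClipValuesSymmetric(1e4, exclude_targets=("view",)))
bart_manager.register(1, ClipValuesSymmetric(10000, exclude_targets=("view",)))
```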
> Parameters for controlling the batch size @@ -161,6 +166,8 @@ def __init__(self, **kwargs): self.log_insertions = kwargs.pop("log_insertions", False) + self.optimization_level = kwargs.pop("optimization_level", 1) + def _prepare_config_attribute_for_pod_type( self, config_attribute_name: str, config_attribute: Union[Any, Dict[str, Any]], pod_type: Optional[str] ) -> Any: diff --git a/optimum/graphcore/models/bart/modeling_bart.py b/optimum/graphcore/models/bart/modeling_bart.py index 8f8c32ca8..8d632610e 100644 --- a/optimum/graphcore/models/bart/modeling_bart.py +++ b/optimum/graphcore/models/bart/modeling_bart.py @@ -22,8 +22,8 @@ from transformers.modeling_outputs import Seq2SeqLMOutput, Seq2SeqSequenceClassifierOutput from transformers.models.bart.modeling_bart import BartAttention -from ....fx.optimization import ChangeTrueDivToMulByInverse, MergeLinears, ReversibleTransformation, compose -from ...fx.transformations import ( +from ....fx.optimization import ReversibleTransformation, compose +from ...fx import ( AddPoptorchBlock, AddPoptorchBlocksInSeries, ClipValues, @@ -32,10 +32,10 @@ RecomputationCheckpoint, ShareEmbeddingComputation, TieWeights, - TupleOutput, VocabEmbeddingToSerializedEmbedding, + symbolic_trace_pipelined_model, + DEFAULT_TRANSFORMATION_MANAGER, ) -from ...fx.utils import symbolic_trace_pipelined_model from ...generation_utils import IPUGenerationMixin from ...modeling_utils import GenerationMethodsMixin, PipelineMixin, get_layer_ipu, register @@ -44,16 +44,8 @@ FLOAT16_LIMIT = 1e4 -_OPTIMIZATION_TRANSFORMATIONS = [ - ChangeTrueDivToMulByInverse(), - MergeLinears(), - # FuseBiasInLinear(), -] - -_NON_REVERSIBLE_TRANSFORMATIONS = [ - ClipValuesSymmetric(10000, exclude_targets=["view"]), - ClipValues(1e-4, float("inf"), include_targets=[torch.nn.LayerNorm]), -] +TRANSFORMATION_MANAGER = DEFAULT_TRANSFORMATION_MANAGER.without(ClipValuesSymmetric(1e4, exclude_targets=("view",))) +TRANSFORMATION_MANAGER.register(1, ClipValuesSymmetric(10000, exclude_targets=("view",)) def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0): @@ -201,6 +193,53 @@ def forward( return attn_output, attn_weights_reshaped, past_key_value +class BartPipelineMixin(PipelineMixin): + def parallelize(self): + """ + Transform the model to run in an IPU pipeline. 
+ - Adds pipeline stages to the model + - (If enabled) Replaces the shared embedding with a SerializedEmbedding + - Adds recomputation checkpoints + + Recommended usage: + ``` + model = PipelinedBartForConditionalGeneration(config).parallelize().half() + ``` + """ + super().parallelize() + orig_make_causal_mask = transformers.models.bart.modeling_bart._make_causal_mask + orig_expand_mask = transformers.models.bart.modeling_bart._expand_mask + transformers.models.bart.modeling_bart._make_causal_mask = _make_causal_mask + transformers.models.bart.modeling_bart._expand_mask = _expand_mask + for mod in self.modules(): + if isinstance(mod, BartAttention): + mod.__class__ = _BartAttentionWithoutException + traced = symbolic_trace_pipelined_model(self) + transformers.models.bart.modeling_bart._make_causal_mask = orig_make_causal_mask + transformers.models.bart.modeling_bart._expand_mask = orig_expand_mask + transformations = self.get_transformations() + transformations += TRANSFORMATION_MANAGER.get_reversible_transformations(self.ipu_config.optimization_level) + composition = compose(*transformations) + non_reversible_composition = TRANSFORMATION_MANAGER.compose_non_reversible_transformations(self.ipu_config.optimization_level) + traced = composition(traced) + traced = non_reversible_composition(traced) + return traced + + def deparallelize(self): + """ + Undo the changes to the model done by `parallelize`. + You should call this before doing `save_pretrained` so that the `model.state_dict` is + fully compatible with `transformers.BartForConditionalGeneration`. + """ + super().deparallelize() + transformations = self.get_transformations() + transformations = [t for t in transformations if isinstance(t, ReversibleTransformation)] + transformations += TRANSFORMATION_MANAGER.get_reversible_transformations(self.ipu_config.optimization_level) + composition = compose(*transformations) + self = composition(self, reverse=True) + return self + + @register(BartForConditionalGeneration) class PipelinedBartForConditionalGeneration( GenerationMethodsMixin, BartForConditionalGeneration, PipelineMixin, IPUGenerationMixin @@ -247,51 +286,6 @@ def get_transformations(self): transformations += [ShareEmbeddingComputation()] return transformations - def parallelize(self): - """ - Transform the model to run in an IPU pipeline. 
- - Adds pipeline stages to the model - - (If enabled) Replaces the shared embedding with a SerializedEmbedding - - Adds recomputation checkpoints - - Recommended usage: - ``` - model = PipelinedBartForConditionalGeneration(config).parallelize().half() - ``` - """ - super().parallelize() - orig_make_causal_mask = transformers.models.bart.modeling_bart._make_causal_mask - orig_expand_mask = transformers.models.bart.modeling_bart._expand_mask - transformers.models.bart.modeling_bart._make_causal_mask = _make_causal_mask - transformers.models.bart.modeling_bart._expand_mask = _expand_mask - for mod in self.modules(): - if isinstance(mod, BartAttention): - mod.__class__ = _BartAttentionWithoutException - traced = symbolic_trace_pipelined_model(self) - transformers.models.bart.modeling_bart._make_causal_mask = orig_make_causal_mask - transformers.models.bart.modeling_bart._expand_mask = orig_expand_mask - transformations = self.get_transformations() - transformations += _OPTIMIZATION_TRANSFORMATIONS - composition = compose(*transformations) - non_reversible_composition = compose(*_NON_REVERSIBLE_TRANSFORMATIONS) - traced = composition(traced) - traced = non_reversible_composition(traced) - return traced - - def deparallelize(self): - """ - Undo the changes to the model done by `parallelize`. - You should call this before doing `save_pretrained` so that the `model.state_dict` is - fully compatible with `transformers.BartForConditionalGeneration`. - """ - super().deparallelize() - transformations = self.get_transformations() - transformations += _OPTIMIZATION_TRANSFORMATIONS - transformations = [t for t in transformations if isinstance(t, ReversibleTransformation)] - composition = compose(*transformations) - self = composition(self, reverse=True) - return self - def forward( self, input_ids: torch.LongTensor = None, @@ -390,51 +384,6 @@ def get_transformations(self): transformations += [ShareEmbeddingComputation()] return transformations - def parallelize(self): - """ - Transform the model to run in an IPU pipeline. - - Adds pipeline stages to the model - - (If enabled) Replaces the shared embedding with a SerializedEmbedding - - Adds recomputation checkpoints - - Recommended usage: - ``` - model = PipelinedBartForConditionalGeneration(config).parallelize().half() - ``` - """ - super().parallelize() - orig_make_causal_mask = transformers.models.bart.modeling_bart._make_causal_mask - orig_expand_mask = transformers.models.bart.modeling_bart._expand_mask - transformers.models.bart.modeling_bart._make_causal_mask = _make_causal_mask - transformers.models.bart.modeling_bart._expand_mask = _expand_mask - for mod in self.modules(): - if isinstance(mod, BartAttention): - mod.__class__ = _BartAttentionWithoutException - traced = symbolic_trace_pipelined_model(self) - transformers.models.bart.modeling_bart._make_causal_mask = orig_make_causal_mask - transformers.models.bart.modeling_bart._expand_mask = orig_expand_mask - transformations = self.get_transformations() - transformations += _OPTIMIZATION_TRANSFORMATIONS - composition = compose(*transformations) - non_reversible_composition = compose(*_NON_REVERSIBLE_TRANSFORMATIONS) - traced = composition(traced) - traced = non_reversible_composition(traced) - return traced - - def deparallelize(self): - """ - Undo the changes to the model done by `parallelize`. - You should call this before doing `save_pretrained` so that the `model.state_dict` is - fully compatible with `transformers.BartForConditionalGeneration`. 
- """ - super().deparallelize() - transformations = self.get_transformations() - transformations += _OPTIMIZATION_TRANSFORMATIONS - transformations = [t for t in transformations if isinstance(t, ReversibleTransformation)] - composition = compose(*transformations) - self = composition(self, reverse=True) - return self - def forward( self, input_ids: torch.LongTensor = None, diff --git a/optimum/graphcore/models/bert/modeling_bert.py b/optimum/graphcore/models/bert/modeling_bert.py index c91aee659..65816a103 100644 --- a/optimum/graphcore/models/bert/modeling_bert.py +++ b/optimum/graphcore/models/bert/modeling_bert.py @@ -32,8 +32,8 @@ ) from transformers.utils.fx import _gen_constructor_wrapper -from ....fx.optimization import ChangeTrueDivToMulByInverse, MergeLinears, compose -from ...fx.transformations import ( +from ....fx.optimization import compose +from ...fx import ( AddPoptorchBlock, AddPoptorchBlocksInSeries, ClipValues, @@ -42,30 +42,76 @@ OutlineAttribute, RecomputationCheckpoint, TieWeights, - TupleOutput, VocabEmbeddingToSerializedEmbedding, + symbolic_trace_pipelined_model, + DEFAULT_TRANSFORMATION_MANAGER, ) -from ...fx.utils import symbolic_trace_pipelined_model from ...modeling_utils import OnehotGather, PipelineMixin, get_layer_ipu, register logger = logging.get_logger(__name__) -_OPTIMIZATION_TRANSFORMATIONS = [ - ChangeTrueDivToMulByInverse(), - MergeLinears(), - # FuseBiasInLinear(), -] +class BertPipelineMixin(PipelineMixin): + + def get_transformations(self): + log_insertions = self.ipu_config.log_insertions + layer_ipu = get_layer_ipu(self.ipu_config.layers_per_ipu) + last_ipu = len(self.ipu_config.layers_per_ipu) - 1 + transformations = [ + AddPoptorchBlock("Embedding", 0, "bert.embeddings", log_insertions=log_insertions), + OutlineAttribute("bert.embeddings.LayerNorm", "Embedding"), + AddPoptorchBlocksInSeries( + "Encoder", layer_ipu, r"bert.encoder.layer.[0-9]+", log_insertions=log_insertions + ), + # Only one of the following AddPoptorchBlock, will actually add a block. + AddPoptorchBlock("Classifier Output", last_ipu, "classifier", log_insertions=log_insertions), + AddPoptorchBlock("QA Outputs", last_ipu, "qa_outputs", log_insertions=log_insertions), + ] + if self.ipu_config.recompute_checkpoint_every_layer: + transformations.append( + RecomputationCheckpoint( + "bert.encoder.layer.[0-9]+", to_exclude=f"bert.encoder.layer.{self.config.num_hidden_layers - 1}" + ) + ) + if self.ipu_config.embedding_serialization_factor > 1: + transformations.append(VocabEmbeddingToSerializedEmbedding()) + return transformations + + def parallelize(self): + """ + Transform the model to run in an IPU pipeline. 
+ - Adds pipeline stages to the model + - Replaces self-attention layers with fused-qkv self-attention layers + - (If enabled) Replaces the word embedding with a SerializedEmbedding + - Adds recomputation checkpoints + """ + super().parallelize() + traced = symbolic_trace_pipelined_model(self) + transformations = self.get_transformations() + transformations += DEFAULT_TRANSFORMATION_MANAGER.get_reversible_transformations(self.ipu_config.optimization_level) + composition = compose(*transformations) + non_reversible_composition = DEFAULT_TRANSFORMATION_MANAGER.compose_non_reversible_transformations(self.ipu_config.optimization_level) + traced = composition(traced) + traced = non_reversible_composition(traced) + return traced -_NON_REVERSIBLE_TRANSFORMATIONS = [ - ClipValuesSymmetric(1e4, exclude_targets=["view"]), - ClipValues(1e-4, float("inf"), include_targets=[torch.nn.LayerNorm]), -] + def deparallelize(self): + """ + Undo the changes to the model done by `parallelize`. + You should call this before doing `save_pretrained` so that the `model.state_dict` is + compatible with the original model. + """ + super().deparallelize() + transformations = self.get_transformations() + transformations += DEFAULT_TRANSFORMATION_MANAGER.get_reversible_transformations(self.ipu_config.optimization_level) + composition = compose(*transformations) + self = composition(self, reverse=True) + return self @register(BertForPreTraining) -class PipelinedBertForPreTraining(BertForPreTraining, PipelineMixin): +class PipelinedBertForPreTraining(BertForPreTraining, BertPipelineMixin): """ BertForPretraining transformed to run in an IPU pipeline. @@ -110,37 +156,6 @@ def get_transformations(self): ] return transformations - def parallelize(self): - """ - Transform the model to run in an IPU pipeline. - - Adds pipeline stages to the model - - Replaces self-attention layers with fused-qkv self-attention layers - - (If enabled) Replaces the word embedding projection with a SerializedLinear layer - - Adds recomputation checkpoints - """ - super().parallelize() - traced = symbolic_trace_pipelined_model(self) - transformations = self.get_transformations() - transformations += _OPTIMIZATION_TRANSFORMATIONS - composition = compose(*transformations) - non_reversible_composition = compose(*_NON_REVERSIBLE_TRANSFORMATIONS) - traced = composition(traced) - traced = non_reversible_composition(traced) - return traced - - def deparallelize(self): - """ - Undo the changes to the model done by `parallelize`. - You should call this before doing `save_pretrained` so that the `model.state_dict` is - compatible with the original model. - """ - super().deparallelize() - transformations = self.get_transformations() - transformations += _OPTIMIZATION_TRANSFORMATIONS - composition = compose(*transformations) - self = composition(self, reverse=True) - return self - def _init_weights(self, module): """Initialize the weights""" @@ -227,7 +242,7 @@ def forward( @register(BertForMaskedLM) -class PipelinedBertForMaskedLM(BertForMaskedLM, PipelineMixin): +class PipelinedBertForMaskedLM(BertForMaskedLM, BertPipelineMixin): """ BertForMaskedLM transformed to run in an IPU pipeline. @@ -271,37 +286,6 @@ def get_transformations(self): ] return transformations - def parallelize(self): - """ - Transform the model to run in an IPU pipeline. 
- - Adds pipeline stages to the model - - Replaces self-attention layers with fused-qkv self-attention layers - - (If enabled) Replaces the word embedding projection with a SerializedLinear layer - - Adds recomputation checkpoints - """ - super().parallelize() - traced = symbolic_trace_pipelined_model(self) - transformations = self.get_transformations() - transformations += _OPTIMIZATION_TRANSFORMATIONS - composition = compose(*transformations) - non_reversible_composition = compose(*_NON_REVERSIBLE_TRANSFORMATIONS) - traced = composition(traced) - traced = non_reversible_composition(traced) - return traced - - def deparallelize(self): - """ - Undo the changes to the model done by `parallelize`. - You should call this before doing `save_pretrained` so that the `model.state_dict` is - compatible with the original model. - """ - super().deparallelize() - transformations = self.get_transformations() - transformations += _OPTIMIZATION_TRANSFORMATIONS - composition = compose(*transformations) - self = composition(self, reverse=True) - return self - def forward( self, input_ids: Optional[torch.Tensor] = None, @@ -376,63 +360,6 @@ def forward( ) -class BertPipelineMixin(PipelineMixin): - def get_transformations(self): - log_insertions = self.ipu_config.log_insertions - layer_ipu = get_layer_ipu(self.ipu_config.layers_per_ipu) - last_ipu = len(self.ipu_config.layers_per_ipu) - 1 - transformations = [ - AddPoptorchBlock("Embedding", 0, "bert.embeddings", log_insertions=log_insertions), - OutlineAttribute("bert.embeddings.LayerNorm", "Embedding"), - AddPoptorchBlocksInSeries( - "Encoder", layer_ipu, r"bert.encoder.layer.[0-9]+", log_insertions=log_insertions - ), - # Only one of the following AddPoptorchBlock, will actually add a block. - AddPoptorchBlock("Classifier Output", last_ipu, "classifier", log_insertions=log_insertions), - AddPoptorchBlock("QA Outputs", last_ipu, "qa_outputs", log_insertions=log_insertions), - ] - if self.ipu_config.recompute_checkpoint_every_layer: - transformations.append( - RecomputationCheckpoint( - "bert.encoder.layer.[0-9]+", to_exclude=f"bert.encoder.layer.{self.config.num_hidden_layers - 1}" - ) - ) - if self.ipu_config.embedding_serialization_factor > 1: - transformations.append(VocabEmbeddingToSerializedEmbedding()) - return transformations - - def parallelize(self): - """ - Transform the model to run in an IPU pipeline. - - Adds pipeline stages to the model - - Replaces self-attention layers with fused-qkv self-attention layers - - (If enabled) Replaces the word embedding with a SerializedEmbedding - - Adds recomputation checkpoints - """ - super().parallelize() - traced = symbolic_trace_pipelined_model(self) - transformations = self.get_transformations() - transformations += _OPTIMIZATION_TRANSFORMATIONS - composition = compose(*transformations) - non_reversible_composition = compose(*_NON_REVERSIBLE_TRANSFORMATIONS) - traced = composition(traced) - traced = non_reversible_composition(traced) - return traced - - def deparallelize(self): - """ - Undo the changes to the model done by `parallelize`. - You should call this before doing `save_pretrained` so that the `model.state_dict` is - compatible with the original model. 
- """ - super().deparallelize() - transformations = self.get_transformations() - transformations += _OPTIMIZATION_TRANSFORMATIONS - composition = compose(*transformations) - self = composition(self, reverse=True) - return self - - @register(BertForSequenceClassification) class PipelinedBertForSequenceClassification(BertForSequenceClassification, BertPipelineMixin): """ diff --git a/optimum/graphcore/models/deberta/modeling_deberta.py b/optimum/graphcore/models/deberta/modeling_deberta.py index ea2d32404..8be9ea934 100644 --- a/optimum/graphcore/models/deberta/modeling_deberta.py +++ b/optimum/graphcore/models/deberta/modeling_deberta.py @@ -118,6 +118,9 @@ def _get_rel_embedding(self): return self.rel_embeddings.weight + 0.0 if self.relative_attention else None +gather_last_dim = FastGatherLastDim() + + class IPUDisentangledSelfAttention(DisentangledSelfAttention): """ Disentangled self-attention module @@ -132,7 +135,8 @@ class IPUDisentangledSelfAttention(DisentangledSelfAttention): def __init__(self, config): super().__init__(config) self.xsoftmax = XSoftmax(-1) - self.gather_last_dim = FastGatherLastDim() + # self.gather_last_dim = FastGatherLastDim() + self.gather_last_dim = gather_last_dim def forward( self, diff --git a/optimum/graphcore/models/gpt2/modeling_gpt2.py b/optimum/graphcore/models/gpt2/modeling_gpt2.py index 52dc67921..a3f571764 100644 --- a/optimum/graphcore/models/gpt2/modeling_gpt2.py +++ b/optimum/graphcore/models/gpt2/modeling_gpt2.py @@ -18,43 +18,27 @@ import torch import torch.nn as nn -import poptorch from transformers import GPT2ForSequenceClassification, GPT2ForTokenClassification, GPT2LMHeadModel from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions, SequenceClassifierOutputWithPast -from transformers.models.gpt2.modeling_gpt2 import GPT2Attention -from ....fx.optimization import ChangeTrueDivToMulByInverse, MergeLinears, ReversibleTransformation, compose +from ....fx.optimization import ReversibleTransformation, compose from ....utils import logging -from ...fx.transformations import ( +from ...fx import ( AddPoptorchBlock, AddPoptorchBlocksInSeries, - ClipValues, - ClipValuesSymmetric, LinearToSerializedLinear, OutlineAttribute, RecomputationCheckpoint, TieWeights, - TupleOutput, VocabEmbeddingToSerializedEmbedding, + symbolic_trace_pipelined_model, + DEFAULT_TRANSFORMATION_MANAGER, ) -from ...fx.utils import symbolic_trace_pipelined_model from ...modeling_utils import PipelineMixin, get_layer_ipu, register -from .optimized_gpt2_attn import OptimizedGPT2Attention logger = logging.get_logger(__name__) -_OPTIMIZATION_TRANSFORMATIONS = [ - ChangeTrueDivToMulByInverse(), - MergeLinears(), - # FuseBiasInLinear(), -] - -_NON_REVERSIBLE_TRANSFORMATIONS = [ - ClipValuesSymmetric(1e4, exclude_targets=["view"]), - ClipValues(1e-4, float("inf"), include_targets=[torch.nn.LayerNorm]), -] - class GPT2PipelineMixin(PipelineMixin): @property @@ -114,9 +98,9 @@ def parallelize(self): self.resize_vocab(False) traced = symbolic_trace_pipelined_model(self) transformations = self.get_transformations() - transformations += _OPTIMIZATION_TRANSFORMATIONS + transformations += DEFAULT_TRANSFORMATION_MANAGER.get_reversible_transformations(self.ipu_config.optimization_level) composition = compose(*transformations) - non_reversible_composition = compose(*_NON_REVERSIBLE_TRANSFORMATIONS) + non_reversible_composition = DEFAULT_TRANSFORMATION_MANAGER.compose_non_reversible_transformations(self.ipu_config.optimization_level) traced = composition(traced) traced = 
non_reversible_composition(traced) return traced @@ -129,8 +113,8 @@ def deparallelize(self): """ super().deparallelize() transformations = self.get_transformations() - transformations += _OPTIMIZATION_TRANSFORMATIONS transformations = [t for t in transformations if isinstance(t, ReversibleTransformation)] + transformations += DEFAULT_TRANSFORMATION_MANAGER.get_reversible_transformations(self.ipu_config.optimization_level) composition = compose(*transformations) self = composition(self, reverse=True) if self.ipu_config.embedding_serialization_factor > 1: @@ -165,41 +149,6 @@ def get_transformations(self): return transformations - def parallelize(self): - """ - Transform the Roberta model body to run in an IPU pipeline. - - Adds pipeline stages to the model - - (If enabled) Replaces the word embedding with a SerializedEmbedding - - Adds recomputation checkpoints - """ - PipelineMixin.parallelize(self) - if self.ipu_config.embedding_serialization_factor > 1: - self.resize_vocab(False) - traced = symbolic_trace_pipelined_model(self) - transformations = self.get_transformations() - transformations += _OPTIMIZATION_TRANSFORMATIONS - composition = compose(*transformations) - non_reversible_composition = compose(*_NON_REVERSIBLE_TRANSFORMATIONS) - traced = composition(traced) - traced = non_reversible_composition(traced) - return traced - - def deparallelize(self): - """ - Undo the changes to the model done by `parallelize`. - You should call this before doing `save_pretrained` so that the `model.state_dict` is - fully compatible with the original model. - """ - PipelineMixin.deparallelize(self) - transformations = self.get_transformations() - transformations += _OPTIMIZATION_TRANSFORMATIONS - transformations = [t for t in transformations if isinstance(t, ReversibleTransformation)] - composition = compose(*transformations) - self = composition(self, reverse=True) - if self.ipu_config.embedding_serialization_factor > 1: - self.resize_vocab(True) - return self - def forward( self, input_ids: Optional[torch.LongTensor] = None, diff --git a/optimum/graphcore/models/hubert/ipu_layer_drop.py b/optimum/graphcore/models/hubert/ipu_layer_drop.py index 92bf1e889..494abd5bd 100644 --- a/optimum/graphcore/models/hubert/ipu_layer_drop.py +++ b/optimum/graphcore/models/hubert/ipu_layer_drop.py @@ -62,7 +62,7 @@ def forward( # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) # Modify LayerDrop so it can be statically compiled without eager mode if self.config.layerdrop > 0.0: - dropout_probability = torch.rand(tuple(), device=hidden_states.device) + dropout_probability = torch.rand((), device=hidden_states.device) skip_the_layer = ( torch.tensor(self.training, device=hidden_states.device) & (dropout_probability < self.config.layerdrop) @@ -125,7 +125,7 @@ def forward( # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) # Modify LayerDrop so it can be statically compiled without eager mode if self.config.layerdrop > 0.0: - dropout_probability = torch.rand(tuple(), device=hidden_states.device) + dropout_probability = torch.rand((), device=hidden_states.device) skip_the_layer = ( torch.tensor(self.training, device=hidden_states.device) & (dropout_probability < self.config.layerdrop) diff --git a/optimum/graphcore/models/hubert/modeling_hubert.py b/optimum/graphcore/models/hubert/modeling_hubert.py index f5a8749d1..5b97fbb69 100644 --- a/optimum/graphcore/models/hubert/modeling_hubert.py +++ b/optimum/graphcore/models/hubert/modeling_hubert.py @@ -11,41 +11,41 @@ # WITHOUT 
WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import torch - -import poptorch from transformers import HubertForSequenceClassification -from ....fx.optimization import ChangeTrueDivToMulByInverse, MergeLinears, compose +from ....fx.optimization import MergeLinears, compose from ....utils import logging -from ...fx.transformations import ( +from ...fx import ( AddPoptorchBlock, AddPoptorchBlocksInSeries, - ClipValues, - ClipValuesSymmetric, RecomputationCheckpoint, - TupleOutput, + symbolic_trace_pipelined_model, + DEFAULT_TRANSFORMATION_MANAGER, ) -from ...fx.utils import symbolic_trace_pipelined_model from ...modeling_utils import PipelineMixin, get_layer_ipu, register logger = logging.get_logger(__name__) -_OPTIMIZATION_TRANSFORMATIONS = [ - ChangeTrueDivToMulByInverse(), - MergeLinears(), - # FuseBiasInLinear(), -] - -_NON_REVERSIBLE_TRANSFORMATIONS = [ - ClipValuesSymmetric(1e4, exclude_targets=["view"]), - ClipValues(1e-4, float("inf"), include_targets=[torch.nn.LayerNorm]), -] +TRANSFORMATION_MANAGER = DEFAULT_TRANSFORMATION_MANAGER.without(MergeLinears()) @register(HubertForSequenceClassification) class PipelinedHubertForSequenceClassification(HubertForSequenceClassification, PipelineMixin): + def change_hubert_encoder_class(self, restore: bool): + """Changes the encoder class to update its forward pass so that it uses our custom version. + Args: + restore: whether to restore the encoder to its original version or not. + """ + from .ipu_layer_drop import IPUHubertEncoder, IPUHubertEncoderStableLayerNorm + from transformers.models.hubert.modeling_hubert import HubertEncoder, HubertEncoderStableLayerNorm + + if self.config.do_stable_layer_norm: + new_cls = HubertEncoderStableLayerNorm if restore else IPUHubertEncoderStableLayerNorm + else: + new_cls = HubertEncoder if restore else IPUHubertEncoder + self.hubert.encoder.__class__ = new_cls + def get_transformations(self): log_insertions = self.ipu_config.log_insertions layer_ipu = get_layer_ipu(self.ipu_config.layers_per_ipu) @@ -70,11 +70,12 @@ def get_transformations(self): def parallelize(self): super().parallelize() + self.change_hubert_encoder_class(False) traced = symbolic_trace_pipelined_model(self) transformations = self.get_transformations() - transformations += _OPTIMIZATION_TRANSFORMATIONS + transformations += TRANSFORMATION_MANAGER.get_reversible_transformations(self.ipu_config.optimization_level) composition = compose(*transformations) - non_reversible_composition = compose(*_NON_REVERSIBLE_TRANSFORMATIONS) + non_reversible_composition = TRANSFORMATION_MANAGER.compose_non_reversible_transformations(self.ipu_config.optimization_level) traced = composition(traced) traced = non_reversible_composition(traced) return traced @@ -82,7 +83,7 @@ def parallelize(self): def deparallelize(self): super().deparallelize() transformations = self.get_transformations() - transformations += _OPTIMIZATION_TRANSFORMATIONS + transformations += TRANSFORMATION_MANAGER.get_reversible_transformations(self.ipu_config.optimization_level) composition = compose(*transformations) self = composition(self, reverse=True) return self diff --git a/optimum/graphcore/models/t5/modeling_t5.py b/optimum/graphcore/models/t5/modeling_t5.py index 18ce59c92..689e6ae91 100644 --- a/optimum/graphcore/models/t5/modeling_t5.py +++ b/optimum/graphcore/models/t5/modeling_t5.py @@ -16,44 +16,29 @@ from typing import Optional, Tuple, Union import torch 
-import torch.nn as nn -import poptorch from optimum.utils import logging from transformers import T5ForConditionalGeneration from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput -from transformers.models.t5.modeling_t5 import __HEAD_MASK_WARNING_MSG, T5LayerNorm +from transformers.models.t5.modeling_t5 import __HEAD_MASK_WARNING_MSG -from ....fx.optimization import ChangeTrueDivToMulByInverse, MergeLinears, ReversibleTransformation, compose -from ...fx.transformations import ( +from ....fx.optimization import ReversibleTransformation, compose +from ...fx import ( AddPoptorchBlock, AddPoptorchBlocksInSeries, - ClipValues, - ClipValuesSymmetric, LinearToSerializedLinear, RecomputationCheckpoint, ShareEmbeddingComputation, TieWeights, - TupleOutput, + symbolic_trace_pipelined_model, + DEFAULT_TRANSFORMATION_MANAGER, ) -from ...fx.utils import symbolic_trace_pipelined_model from ...generation_utils import IPUGenerationMixin -from ...modeling_utils import GenerationMethodsMixin, PipelineMixin, SharedEmbedding, get_layer_ipu, register +from ...modeling_utils import GenerationMethodsMixin, PipelineMixin, get_layer_ipu, register logger = logging.get_logger(__name__) -_OPTIMIZATION_TRANSFORMATIONS = [ - ChangeTrueDivToMulByInverse(), - MergeLinears(), - # FuseBiasInLinear(), -] - -_NON_REVERSIBLE_TRANSFORMATIONS = [ - ClipValuesSymmetric(1e4, exclude_targets=["view"]), - ClipValues(1e-4, float("inf"), include_targets=[torch.nn.LayerNorm]), -] - @register(T5ForConditionalGeneration) class PipelinedT5ForConditionalGeneration( @@ -151,9 +136,9 @@ def parallelize(self): # mod.forward = poptorch.autocast(enabled=True)(mod.forward) traced = symbolic_trace_pipelined_model(self) transformations = self.get_transformations() - transformations += _OPTIMIZATION_TRANSFORMATIONS + transformations += DEFAULT_TRANSFORMATION_MANAGER.get_reversible_transformations(self.ipu_config.optimization_level) composition = compose(*transformations) - non_reversible_composition = compose(*_NON_REVERSIBLE_TRANSFORMATIONS) + non_reversible_composition = DEFAULT_TRANSFORMATION_MANAGER.compose_non_reversible_transformations(self.ipu_config.optimization_level) traced = composition(traced) traced = non_reversible_composition(traced) return traced @@ -167,8 +152,8 @@ def deparallelize(self): # T5ForConditionalGeneration has a deparallelize method, so make sure that the PipelineMixin one is used here. PipelineMixin.deparallelize(self) transformations = self.get_transformations() - transformations += _OPTIMIZATION_TRANSFORMATIONS transformations = [t for t in transformations if isinstance(t, ReversibleTransformation)] + transformations += DEFAULT_TRANSFORMATION_MANAGER.get_reversible_transformations(self.ipu_config.optimization_level) composition = compose(*transformations) self = composition(self, reverse=True) return self diff --git a/optimum/graphcore/models/vit/modeling_vit.py b/optimum/graphcore/models/vit/modeling_vit.py index 45e8d0a40..aff9c1c8d 100644 --- a/optimum/graphcore/models/vit/modeling_vit.py +++ b/optimum/graphcore/models/vit/modeling_vit.py @@ -11,37 +11,22 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
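[Editor's note] Across the model files touched by this commit, the per-file `_OPTIMIZATION_TRANSFORMATIONS` / `_NON_REVERSIBLE_TRANSFORMATIONS` lists are replaced by the shared manager. The sketch below is not part of the patch; it only restates the `deparallelize` round trip that the hunks above repeat, under the assumption that `DEFAULT_TRANSFORMATION_MANAGER` is re-exported from `optimum.graphcore.fx`.

```
# Editor's sketch, not part of the patch.
from optimum.fx.optimization import ReversibleTransformation, compose
from optimum.graphcore.fx import DEFAULT_TRANSFORMATION_MANAGER  # re-export assumed


def deparallelize_with_manager(pipelined_model, manager=DEFAULT_TRANSFORMATION_MANAGER):
    level = pipelined_model.ipu_config.optimization_level
    transformations = pipelined_model.get_transformations()
    # Only reversible transformations can be replayed; clipping and similar
    # non-reversible transformations are filtered out before composing in reverse.
    transformations = [t for t in transformations if isinstance(t, ReversibleTransformation)]
    transformations += manager.get_reversible_transformations(level)
    return compose(*transformations)(pipelined_model, reverse=True)
```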
-import torch - import transformers -from ....fx.optimization import ChangeTrueDivToMulByInverse, MergeLinears, compose +from ....fx.optimization import compose from ....utils import logging -from ...fx.transformations import ( +from ...fx import ( AddPoptorchBlock, AddPoptorchBlocksInSeries, - ClipValues, - ClipValuesSymmetric, RecomputationCheckpoint, - TupleOutput, + symbolic_trace_pipelined_model, + DEFAULT_TRANSFORMATION_MANAGER, ) -from ...fx.utils import symbolic_trace_pipelined_model from ...modeling_utils import PipelineMixin, get_layer_ipu, register logger = logging.get_logger(__name__) -_OPTIMIZATION_TRANSFORMATIONS = [ - ChangeTrueDivToMulByInverse(), - MergeLinears(), - # FuseBiasInLinear(), -] - -_NON_REVERSIBLE_TRANSFORMATIONS = [ - ClipValuesSymmetric(1e4, exclude_targets=["view"]), - ClipValues(1e-4, float("inf"), include_targets=[torch.nn.LayerNorm]), -] - @register(transformers.ViTForImageClassification) class PipelinedViTForImageClassification(transformers.ViTForImageClassification, PipelineMixin): @@ -67,9 +52,9 @@ def parallelize(self): super().parallelize() traced = symbolic_trace_pipelined_model(self) transformations = self.get_transformations() - transformations += _OPTIMIZATION_TRANSFORMATIONS + transformations += DEFAULT_TRANSFORMATION_MANAGER.get_reversible_transformations(self.ipu_config.optimization_level) composition = compose(*transformations) - non_reversible_composition = compose(*_NON_REVERSIBLE_TRANSFORMATIONS) + non_reversible_composition = DEFAULT_TRANSFORMATION_MANAGER.compose_non_reversible_transformations(self.ipu_config.optimization_level) traced = composition(traced) traced = non_reversible_composition(traced) return traced @@ -77,7 +62,7 @@ def parallelize(self): def deparallelize(self): super().deparallelize() transformations = self.get_transformations() - transformations += _OPTIMIZATION_TRANSFORMATIONS + transformations += DEFAULT_TRANSFORMATION_MANAGER.get_reversible_transformations(self.ipu_config.optimization_level) composition = compose(*transformations) self = composition(self, reverse=True) return self From 0fe6da791a375737705c25a01aa4ed6fe56a5f95 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Thu, 20 Oct 2022 18:35:52 +0200 Subject: [PATCH 23/33] Make style --- optimum/graphcore/fx/utils.py | 4 +- .../graphcore/models/bart/modeling_bart.py | 9 +-- .../graphcore/models/bert/modeling_bert.py | 17 +++-- .../models/convnext/modeling_convnext.py | 39 ++++------- .../models/deberta/modeling_deberta.py | 34 ++++------ .../models/distilbert/modeling_distilbert.py | 33 ++++------ .../graphcore/models/gpt2/modeling_gpt2.py | 14 ++-- .../models/hubert/modeling_hubert.py | 9 ++- .../models/lxmert/modeling_lxmert.py | 34 ++++------ .../models/roberta/modeling_roberta.py | 66 ++++--------------- optimum/graphcore/models/t5/modeling_t5.py | 14 ++-- optimum/graphcore/models/vit/modeling_vit.py | 14 ++-- .../models/wav2vec2/modeling_wav2vec2.py | 33 ++++------ 13 files changed, 133 insertions(+), 187 deletions(-) diff --git a/optimum/graphcore/fx/utils.py b/optimum/graphcore/fx/utils.py index d33a132ed..43f61b067 100644 --- a/optimum/graphcore/fx/utils.py +++ b/optimum/graphcore/fx/utils.py @@ -104,7 +104,9 @@ def create_proxy(self, kind, target, args, kwargs, name=None, type_expr=None, pr ("call_method", "to"), ("call_function", torch.full), ] - torch_methods_to_patched_version = {orig: wrapped for (orig, wrapped) in self.patched_torch_methods.values()} + torch_methods_to_patched_version = { + orig: wrapped for (orig, wrapped) in 
self.patched_torch_methods.values() + } for (k, t) in node_types_to_inspect: if kind == k and target == torch_methods_to_patched_version.get(t, t): if float32_dtype_in_args: diff --git a/optimum/graphcore/models/bart/modeling_bart.py b/optimum/graphcore/models/bart/modeling_bart.py index 8d632610e..7b9148090 100644 --- a/optimum/graphcore/models/bart/modeling_bart.py +++ b/optimum/graphcore/models/bart/modeling_bart.py @@ -24,9 +24,9 @@ from ....fx.optimization import ReversibleTransformation, compose from ...fx import ( + DEFAULT_TRANSFORMATION_MANAGER, AddPoptorchBlock, AddPoptorchBlocksInSeries, - ClipValues, ClipValuesSymmetric, LinearToSerializedLinear, RecomputationCheckpoint, @@ -34,7 +34,6 @@ TieWeights, VocabEmbeddingToSerializedEmbedding, symbolic_trace_pipelined_model, - DEFAULT_TRANSFORMATION_MANAGER, ) from ...generation_utils import IPUGenerationMixin from ...modeling_utils import GenerationMethodsMixin, PipelineMixin, get_layer_ipu, register @@ -45,7 +44,7 @@ FLOAT16_LIMIT = 1e4 TRANSFORMATION_MANAGER = DEFAULT_TRANSFORMATION_MANAGER.without(ClipValuesSymmetric(1e4, exclude_targets=("view",))) -TRANSFORMATION_MANAGER.register(1, ClipValuesSymmetric(10000, exclude_targets=("view",)) +TRANSFORMATION_MANAGER.register(1, ClipValuesSymmetric(10000, exclude_targets=("view",))) def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0): @@ -220,7 +219,9 @@ def parallelize(self): transformations = self.get_transformations() transformations += TRANSFORMATION_MANAGER.get_reversible_transformations(self.ipu_config.optimization_level) composition = compose(*transformations) - non_reversible_composition = TRANSFORMATION_MANAGER.compose_non_reversible_transformations(self.ipu_config.optimization_level) + non_reversible_composition = TRANSFORMATION_MANAGER.compose_non_reversible_transformations( + self.ipu_config.optimization_level + ) traced = composition(traced) traced = non_reversible_composition(traced) return traced diff --git a/optimum/graphcore/models/bert/modeling_bert.py b/optimum/graphcore/models/bert/modeling_bert.py index 65816a103..b29806306 100644 --- a/optimum/graphcore/models/bert/modeling_bert.py +++ b/optimum/graphcore/models/bert/modeling_bert.py @@ -34,17 +34,15 @@ from ....fx.optimization import compose from ...fx import ( + DEFAULT_TRANSFORMATION_MANAGER, AddPoptorchBlock, AddPoptorchBlocksInSeries, - ClipValues, - ClipValuesSymmetric, LinearToSerializedLinear, OutlineAttribute, RecomputationCheckpoint, TieWeights, VocabEmbeddingToSerializedEmbedding, symbolic_trace_pipelined_model, - DEFAULT_TRANSFORMATION_MANAGER, ) from ...modeling_utils import OnehotGather, PipelineMixin, get_layer_ipu, register @@ -53,7 +51,6 @@ class BertPipelineMixin(PipelineMixin): - def get_transformations(self): log_insertions = self.ipu_config.log_insertions layer_ipu = get_layer_ipu(self.ipu_config.layers_per_ipu) @@ -89,9 +86,13 @@ def parallelize(self): super().parallelize() traced = symbolic_trace_pipelined_model(self) transformations = self.get_transformations() - transformations += DEFAULT_TRANSFORMATION_MANAGER.get_reversible_transformations(self.ipu_config.optimization_level) + transformations += DEFAULT_TRANSFORMATION_MANAGER.get_reversible_transformations( + self.ipu_config.optimization_level + ) composition = compose(*transformations) - non_reversible_composition = DEFAULT_TRANSFORMATION_MANAGER.compose_non_reversible_transformations(self.ipu_config.optimization_level) + non_reversible_composition = 
DEFAULT_TRANSFORMATION_MANAGER.compose_non_reversible_transformations( + self.ipu_config.optimization_level + ) traced = composition(traced) traced = non_reversible_composition(traced) return traced @@ -104,7 +105,9 @@ def deparallelize(self): """ super().deparallelize() transformations = self.get_transformations() - transformations += DEFAULT_TRANSFORMATION_MANAGER.get_reversible_transformations(self.ipu_config.optimization_level) + transformations += DEFAULT_TRANSFORMATION_MANAGER.get_reversible_transformations( + self.ipu_config.optimization_level + ) composition = compose(*transformations) self = composition(self, reverse=True) return self diff --git a/optimum/graphcore/models/convnext/modeling_convnext.py b/optimum/graphcore/models/convnext/modeling_convnext.py index fc61750a8..19be99e97 100644 --- a/optimum/graphcore/models/convnext/modeling_convnext.py +++ b/optimum/graphcore/models/convnext/modeling_convnext.py @@ -14,40 +14,23 @@ import torch from torch import nn -from transformers.models.convnext.modeling_convnext import ( - ConvNextForImageClassification, - ConvNextLayer, - ConvNextLayerNorm, -) +from transformers.models.convnext.modeling_convnext import ConvNextForImageClassification, ConvNextLayerNorm -from ....fx.optimization import ChangeTrueDivToMulByInverse, MergeLinears, compose +from ....fx.optimization import compose from ....utils import logging -from ...fx.transformations import ( +from ...fx import ( + DEFAULT_TRANSFORMATION_MANAGER, AddPoptorchBlock, AddPoptorchBlocksInSeries, - ClipValues, - ClipValuesSymmetric, RecomputationCheckpoint, - TupleOutput, + symbolic_trace_pipelined_model, ) -from ...fx.utils import symbolic_trace_pipelined_model from ...modeling_utils import PipelineMixin, get_layer_ipu, register from .optimized_convnextlayer import OptimizedConvNextLayer logger = logging.get_logger(__name__) -_OPTIMIZATION_TRANSFORMATIONS = [ - ChangeTrueDivToMulByInverse(), - MergeLinears(), - # FuseBiasInLinear(), -] - -_NON_REVERSIBLE_TRANSFORMATIONS = [ - ClipValuesSymmetric(1e4, exclude_targets=["view"]), - ClipValues(1e-4, float("inf"), include_targets=[torch.nn.LayerNorm]), -] - class IPUConvNextLayerNorm(nn.Module): r"""LayerNorm that supports two data formats: channels_last (default) or channels_first. 
@@ -117,9 +100,13 @@ def parallelize(self): traced = symbolic_trace_pipelined_model(self) transformations = self.get_transformations() - transformations += _OPTIMIZATION_TRANSFORMATIONS + transformations += DEFAULT_TRANSFORMATION_MANAGER.get_reversible_transformations( + self.ipu_config.optimization_level + ) composition = compose(*transformations) - non_reversible_composition = compose(*_NON_REVERSIBLE_TRANSFORMATIONS) + non_reversible_composition = DEFAULT_TRANSFORMATION_MANAGER.compose_non_reversible_transformations( + self.ipu_config.optimization_level + ) traced = composition(traced) traced = non_reversible_composition(traced) return traced @@ -133,7 +120,9 @@ def deparallelize(self): mod.__class__ = ConvNextLayerNorm transformations = self.get_transformations() - transformations += _OPTIMIZATION_TRANSFORMATIONS + transformations += DEFAULT_TRANSFORMATION_MANAGER.get_reversible_transformations( + self.ipu_config.optimization_level + ) composition = compose(*transformations) self = composition(self, reverse=True) return self diff --git a/optimum/graphcore/models/deberta/modeling_deberta.py b/optimum/graphcore/models/deberta/modeling_deberta.py index 8be9ea934..d6c9e5df1 100644 --- a/optimum/graphcore/models/deberta/modeling_deberta.py +++ b/optimum/graphcore/models/deberta/modeling_deberta.py @@ -20,7 +20,6 @@ import torch.nn as nn import poptorch -import transformers from transformers import ( DebertaForMaskedLM, DebertaForQuestionAnswering, @@ -36,36 +35,23 @@ ) from transformers.utils.fx import _gen_constructor_wrapper -from ....fx.optimization import ChangeTrueDivToMulByInverse, MergeLinears, compose +from ....fx.optimization import MergeLinears, compose from ....utils import logging -from ...fx.transformations import ( +from ...fx import ( + DEFAULT_TRANSFORMATION_MANAGER, AddPoptorchBlock, AddPoptorchBlocksInSeries, AutoCast, - ClipValues, - ClipValuesSymmetric, OutlineAttribute, RecomputationCheckpoint, - TupleOutput, VocabEmbeddingToSerializedEmbedding, + symbolic_trace_pipelined_model, ) -from ...fx.utils import symbolic_trace_pipelined_model from ...modeling_utils import OnehotGather, PipelineMixin, get_layer_ipu, register logger = logging.get_logger(__name__) -_OPTIMIZATION_TRANSFORMATIONS = [ - ChangeTrueDivToMulByInverse(), - MergeLinears(), - # FuseBiasInLinear(), -] - -_NON_REVERSIBLE_TRANSFORMATIONS = [ - ClipValuesSymmetric(1e4, exclude_targets=["view"]), - ClipValues(1e-4, float("inf"), include_targets=[torch.nn.LayerNorm]), -] - class FastGatherLastDim(nn.Module): """ @@ -349,9 +335,13 @@ def parallelize(self): traced = symbolic_trace_pipelined_model(self) torch.nn.functional.one_hot = orig transformations = self.get_transformations() - transformations += _OPTIMIZATION_TRANSFORMATIONS + transformations += DEFAULT_TRANSFORMATION_MANAGER.get_reversible_transformations( + self.ipu_config.optimization_level + ) composition = compose(*transformations) - non_reversible_composition = compose(*_NON_REVERSIBLE_TRANSFORMATIONS) + non_reversible_composition = DEFAULT_TRANSFORMATION_MANAGER.compose_non_reversible_transformations( + self.ipu_config.optimization_level + ) traced = composition(traced) traced = non_reversible_composition(traced) return traced @@ -365,7 +355,9 @@ def deparallelize(self): super().deparallelize() self.change_modules_for_ipu(True) transformations = self.get_transformations() - transformations += _OPTIMIZATION_TRANSFORMATIONS + transformations += DEFAULT_TRANSFORMATION_MANAGER.get_reversible_transformations( + self.ipu_config.optimization_level + ) 
composition = compose(*transformations) self = composition(self, reverse=True) return self diff --git a/optimum/graphcore/models/distilbert/modeling_distilbert.py b/optimum/graphcore/models/distilbert/modeling_distilbert.py index e64a27a5b..03ddb53e3 100644 --- a/optimum/graphcore/models/distilbert/modeling_distilbert.py +++ b/optimum/graphcore/models/distilbert/modeling_distilbert.py @@ -31,36 +31,23 @@ from transformers.modeling_outputs import MaskedLMOutput, QuestionAnsweringModelOutput from transformers.models.distilbert.modeling_distilbert import MultiHeadSelfAttention -from ....fx.optimization import ChangeTrueDivToMulByInverse, MergeLinears, compose -from ...fx.transformations import ( +from ....fx.optimization import compose +from ...fx import ( + DEFAULT_TRANSFORMATION_MANAGER, AddPoptorchBlock, AddPoptorchBlocksInSeries, - ClipValues, - ClipValuesSymmetric, LinearToSerializedLinear, OutlineAttribute, RecomputationCheckpoint, TieWeights, - TupleOutput, VocabEmbeddingToSerializedEmbedding, + symbolic_trace_pipelined_model, ) -from ...fx.utils import symbolic_trace_pipelined_model from ...modeling_utils import OnehotGather, PipelineMixin, get_layer_ipu, register logger = logging.get_logger(__name__) -_OPTIMIZATION_TRANSFORMATIONS = [ - ChangeTrueDivToMulByInverse(), - MergeLinears(), - # FuseBiasInLinear(), -] - -_NON_REVERSIBLE_TRANSFORMATIONS = [ - ClipValuesSymmetric(1e4, exclude_targets=["view"]), - ClipValues(1e-4, float("inf"), include_targets=[torch.nn.LayerNorm]), -] - # TODO: should we make a fx transformation for this? class IPUMultiHeadSelfAttention(MultiHeadSelfAttention): @@ -173,9 +160,13 @@ def parallelize(self): mod.__class__ = IPUMultiHeadSelfAttention traced = symbolic_trace_pipelined_model(self) transformations = self.get_transformations() - transformations += _OPTIMIZATION_TRANSFORMATIONS + transformations += DEFAULT_TRANSFORMATION_MANAGER.get_reversible_transformations( + self.ipu_config.optimization_level + ) composition = compose(*transformations) - non_reversible_composition = compose(*_NON_REVERSIBLE_TRANSFORMATIONS) + non_reversible_composition = DEFAULT_TRANSFORMATION_MANAGER.compose_non_reversible_transformations( + self.ipu_config.optimization_level + ) traced = composition(traced) traced = non_reversible_composition(traced) return traced @@ -191,7 +182,9 @@ def deparallelize(self): if isinstance(mod, IPUMultiHeadSelfAttention): mod.__class__ = MultiHeadSelfAttention transformations = self.get_transformations() - transformations += _OPTIMIZATION_TRANSFORMATIONS + transformations += DEFAULT_TRANSFORMATION_MANAGER.get_reversible_transformations( + self.ipu_config.optimization_level + ) composition = compose(*transformations) self = composition(self, reverse=True) return self diff --git a/optimum/graphcore/models/gpt2/modeling_gpt2.py b/optimum/graphcore/models/gpt2/modeling_gpt2.py index a3f571764..4fea246bf 100644 --- a/optimum/graphcore/models/gpt2/modeling_gpt2.py +++ b/optimum/graphcore/models/gpt2/modeling_gpt2.py @@ -24,6 +24,7 @@ from ....fx.optimization import ReversibleTransformation, compose from ....utils import logging from ...fx import ( + DEFAULT_TRANSFORMATION_MANAGER, AddPoptorchBlock, AddPoptorchBlocksInSeries, LinearToSerializedLinear, @@ -32,7 +33,6 @@ TieWeights, VocabEmbeddingToSerializedEmbedding, symbolic_trace_pipelined_model, - DEFAULT_TRANSFORMATION_MANAGER, ) from ...modeling_utils import PipelineMixin, get_layer_ipu, register @@ -98,9 +98,13 @@ def parallelize(self): self.resize_vocab(False) traced = 
symbolic_trace_pipelined_model(self) transformations = self.get_transformations() - transformations += DEFAULT_TRANSFORMATION_MANAGER.get_reversible_transformations(self.ipu_config.optimization_level) + transformations += DEFAULT_TRANSFORMATION_MANAGER.get_reversible_transformations( + self.ipu_config.optimization_level + ) composition = compose(*transformations) - non_reversible_composition = DEFAULT_TRANSFORMATION_MANAGER.compose_non_reversible_transformations(self.ipu_config.optimization_level) + non_reversible_composition = DEFAULT_TRANSFORMATION_MANAGER.compose_non_reversible_transformations( + self.ipu_config.optimization_level + ) traced = composition(traced) traced = non_reversible_composition(traced) return traced @@ -114,7 +118,9 @@ def deparallelize(self): super().deparallelize() transformations = self.get_transformations() transformations = [t for t in transformations if isinstance(t, ReversibleTransformation)] - transformations += DEFAULT_TRANSFORMATION_MANAGER.get_reversible_transformations(self.ipu_config.optimization_level) + transformations += DEFAULT_TRANSFORMATION_MANAGER.get_reversible_transformations( + self.ipu_config.optimization_level + ) composition = compose(*transformations) self = composition(self, reverse=True) if self.ipu_config.embedding_serialization_factor > 1: diff --git a/optimum/graphcore/models/hubert/modeling_hubert.py b/optimum/graphcore/models/hubert/modeling_hubert.py index 5b97fbb69..a36a31893 100644 --- a/optimum/graphcore/models/hubert/modeling_hubert.py +++ b/optimum/graphcore/models/hubert/modeling_hubert.py @@ -16,11 +16,11 @@ from ....fx.optimization import MergeLinears, compose from ....utils import logging from ...fx import ( + DEFAULT_TRANSFORMATION_MANAGER, AddPoptorchBlock, AddPoptorchBlocksInSeries, RecomputationCheckpoint, symbolic_trace_pipelined_model, - DEFAULT_TRANSFORMATION_MANAGER, ) from ...modeling_utils import PipelineMixin, get_layer_ipu, register @@ -37,9 +37,10 @@ def change_hubert_encoder_class(self, restore: bool): Args: restore: whether to restore the encoder to its original version or not. 
""" - from .ipu_layer_drop import IPUHubertEncoder, IPUHubertEncoderStableLayerNorm from transformers.models.hubert.modeling_hubert import HubertEncoder, HubertEncoderStableLayerNorm + from .ipu_layer_drop import IPUHubertEncoder, IPUHubertEncoderStableLayerNorm + if self.config.do_stable_layer_norm: new_cls = HubertEncoderStableLayerNorm if restore else IPUHubertEncoderStableLayerNorm else: @@ -75,7 +76,9 @@ def parallelize(self): transformations = self.get_transformations() transformations += TRANSFORMATION_MANAGER.get_reversible_transformations(self.ipu_config.optimization_level) composition = compose(*transformations) - non_reversible_composition = TRANSFORMATION_MANAGER.compose_non_reversible_transformations(self.ipu_config.optimization_level) + non_reversible_composition = TRANSFORMATION_MANAGER.compose_non_reversible_transformations( + self.ipu_config.optimization_level + ) traced = composition(traced) traced = non_reversible_composition(traced) return traced diff --git a/optimum/graphcore/models/lxmert/modeling_lxmert.py b/optimum/graphcore/models/lxmert/modeling_lxmert.py index c0ac46a01..9d476198a 100644 --- a/optimum/graphcore/models/lxmert/modeling_lxmert.py +++ b/optimum/graphcore/models/lxmert/modeling_lxmert.py @@ -18,33 +18,21 @@ from transformers.models.lxmert.modeling_lxmert import LxmertForQuestionAnswering, LxmertForQuestionAnsweringOutput -from ....fx.optimization import ChangeTrueDivToMulByInverse, MergeLinears, compose +from ....fx.optimization import MergeLinears, compose from ....utils import logging -from ...fx.transformations import ( +from ...fx import ( + DEFAULT_TRANSFORMATION_MANAGER, AddPoptorchBlock, AddPoptorchBlocksInSeries, - ClipValues, - ClipValuesSymmetric, RecomputationCheckpoint, - TupleOutput, + symbolic_trace_pipelined_model, ) -from ...fx.utils import symbolic_trace_pipelined_model from ...modeling_utils import PipelineMixin, get_layer_ipu, register -logger = logging.get_logger(__name__) - -_OPTIMIZATION_TRANSFORMATIONS = [ - ChangeTrueDivToMulByInverse(), - # TODO: Not working for now. - # MergeLinears(), - # FuseBiasInLinear(), -] +TRANSFORMATION_MANAGER = DEFAULT_TRANSFORMATION_MANAGER.without(MergeLinears()) -_NON_REVERSIBLE_TRANSFORMATIONS = [ - ClipValuesSymmetric(1e4, exclude_targets=["view"]), - ClipValues(1e-4, float("inf"), include_targets=[torch.nn.LayerNorm]), -] +logger = logging.get_logger(__name__) @register(LxmertForQuestionAnswering) @@ -52,7 +40,7 @@ class PipelinedLxmertForQuestionAnswering(LxmertForQuestionAnswering, PipelineMi def get_transformations(self): log_insertions = self.ipu_config.log_insertions layer_ipu = get_layer_ipu(self.ipu_config.layers_per_ipu) - # TODO: remove this line after testing. + # TODO: remove this line after testing. 
layer_ipu = get_layer_ipu([0, 7, 7, 5]) language_layers_ipus = layer_ipu[: self.config.l_layers] visual_layers_ipus = layer_ipu[self.config.l_layers : self.config.l_layers + self.config.r_layers] @@ -94,9 +82,11 @@ def parallelize(self): super().parallelize() traced = symbolic_trace_pipelined_model(self) transformations = self.get_transformations() - transformations += _OPTIMIZATION_TRANSFORMATIONS + transformations += TRANSFORMATION_MANAGER.get_reversible_transformations(self.ipu_config.optimization_level) composition = compose(*transformations) - non_reversible_composition = compose(*_NON_REVERSIBLE_TRANSFORMATIONS) + non_reversible_composition = TRANSFORMATION_MANAGER.compose_non_reversible_transformations( + self.ipu_config.optimization_level + ) traced = composition(traced) traced = non_reversible_composition(traced) return traced @@ -104,7 +94,7 @@ def parallelize(self): def deparallelize(self): super().deparallelize() transformations = self.get_transformations() - transformations += _OPTIMIZATION_TRANSFORMATIONS + transformations += TRANSFORMATION_MANAGER.get_reversible_transformations(self.ipu_config.optimization_level) composition = compose(*transformations) self = composition(self, reverse=True) return self diff --git a/optimum/graphcore/models/roberta/modeling_roberta.py b/optimum/graphcore/models/roberta/modeling_roberta.py index 3183b6ac4..562b49a28 100644 --- a/optimum/graphcore/models/roberta/modeling_roberta.py +++ b/optimum/graphcore/models/roberta/modeling_roberta.py @@ -15,7 +15,6 @@ from typing import Optional, Tuple, Union import torch -import torch.nn as nn from torch.nn import CrossEntropyLoss import poptorch @@ -28,37 +27,24 @@ ) from transformers.modeling_outputs import MaskedLMOutput, QuestionAnsweringModelOutput -from ....fx.optimization import ChangeTrueDivToMulByInverse, MergeLinears, compose +from ....fx.optimization import compose from ....utils import logging -from ...fx.transformations import ( +from ...fx import ( + DEFAULT_TRANSFORMATION_MANAGER, AddPoptorchBlock, AddPoptorchBlocksInSeries, - ClipValues, - ClipValuesSymmetric, LinearToSerializedLinear, OutlineAttribute, RecomputationCheckpoint, TieWeights, - TupleOutput, VocabEmbeddingToSerializedEmbedding, + symbolic_trace_pipelined_model, ) -from ...fx.utils import symbolic_trace_pipelined_model from ...modeling_utils import OnehotGather, PipelineMixin, get_layer_ipu, register logger = logging.get_logger(__name__) -_OPTIMIZATION_TRANSFORMATIONS = [ - ChangeTrueDivToMulByInverse(), - MergeLinears(), - # FuseBiasInLinear(), -] - -_NON_REVERSIBLE_TRANSFORMATIONS = [ - ClipValuesSymmetric(1e4, exclude_targets=["view"]), - ClipValues(1e-4, float("inf"), include_targets=[torch.nn.LayerNorm]), -] - class RobertaPipelineMixin(PipelineMixin): def get_transformations(self): @@ -96,9 +82,13 @@ def parallelize(self): super().parallelize() traced = symbolic_trace_pipelined_model(self) transformations = self.get_transformations() - transformations += _OPTIMIZATION_TRANSFORMATIONS + transformations += DEFAULT_TRANSFORMATION_MANAGER.get_reversible_transformations( + self.ipu_config.optimization_level + ) composition = compose(*transformations) - non_reversible_composition = compose(*_NON_REVERSIBLE_TRANSFORMATIONS) + non_reversible_composition = DEFAULT_TRANSFORMATION_MANAGER.compose_non_reversible_transformations( + self.ipu_config.optimization_level + ) traced = composition(traced) traced = non_reversible_composition(traced) return traced @@ -111,14 +101,16 @@ def deparallelize(self): """ super().deparallelize() 
transformations = self.get_transformations() - transformations += _OPTIMIZATION_TRANSFORMATIONS + transformations += DEFAULT_TRANSFORMATION_MANAGER.get_reversible_transformations( + self.ipu_config.optimization_level + ) composition = compose(*transformations) self = composition(self, reverse=True) return self @register(RobertaForMaskedLM) -class PipelinedRobertaForMaskedLM(RobertaForMaskedLM, PipelineMixin): +class PipelinedRobertaForMaskedLM(RobertaForMaskedLM, RobertaPipelineMixin): """ RobertaForMaskedLM transformed to run in an IPU pipeline. @@ -157,36 +149,6 @@ def get_transformations(self): ] return transformations - def parallelize(self): - """ - Transform the model to run in an IPU pipeline. - - Adds pipeline stages to the model - - (If enabled) Replaces the word embedding projection with a SerializedLinear layer - - Adds recomputation checkpoints - """ - super().parallelize() - traced = symbolic_trace_pipelined_model(self) - transformations = self.get_transformations() - transformations += _OPTIMIZATION_TRANSFORMATIONS - composition = compose(*transformations) - non_reversible_composition = compose(*_NON_REVERSIBLE_TRANSFORMATIONS) - traced = composition(traced) - traced = non_reversible_composition(traced) - return traced - - def deparallelize(self): - """ - Undo the changes to the model done by `parallelize`. - You should call this before doing `save_pretrained` so that the `model.state_dict` is - compatible with the original model. - """ - super().deparallelize() - transformations = self.get_transformations() - transformations += _OPTIMIZATION_TRANSFORMATIONS - composition = compose(*transformations) - self = composition(self, reverse=True) - return self - def forward( self, input_ids: Optional[torch.LongTensor] = None, diff --git a/optimum/graphcore/models/t5/modeling_t5.py b/optimum/graphcore/models/t5/modeling_t5.py index 689e6ae91..00653ebc1 100644 --- a/optimum/graphcore/models/t5/modeling_t5.py +++ b/optimum/graphcore/models/t5/modeling_t5.py @@ -24,6 +24,7 @@ from ....fx.optimization import ReversibleTransformation, compose from ...fx import ( + DEFAULT_TRANSFORMATION_MANAGER, AddPoptorchBlock, AddPoptorchBlocksInSeries, LinearToSerializedLinear, @@ -31,7 +32,6 @@ ShareEmbeddingComputation, TieWeights, symbolic_trace_pipelined_model, - DEFAULT_TRANSFORMATION_MANAGER, ) from ...generation_utils import IPUGenerationMixin from ...modeling_utils import GenerationMethodsMixin, PipelineMixin, get_layer_ipu, register @@ -136,9 +136,13 @@ def parallelize(self): # mod.forward = poptorch.autocast(enabled=True)(mod.forward) traced = symbolic_trace_pipelined_model(self) transformations = self.get_transformations() - transformations += DEFAULT_TRANSFORMATION_MANAGER.get_reversible_transformations(self.ipu_config.optimization_level) + transformations += DEFAULT_TRANSFORMATION_MANAGER.get_reversible_transformations( + self.ipu_config.optimization_level + ) composition = compose(*transformations) - non_reversible_composition = DEFAULT_TRANSFORMATION_MANAGER.compose_non_reversible_transformations(self.ipu_config.optimization_level) + non_reversible_composition = DEFAULT_TRANSFORMATION_MANAGER.compose_non_reversible_transformations( + self.ipu_config.optimization_level + ) traced = composition(traced) traced = non_reversible_composition(traced) return traced @@ -153,7 +157,9 @@ def deparallelize(self): PipelineMixin.deparallelize(self) transformations = self.get_transformations() transformations = [t for t in transformations if isinstance(t, ReversibleTransformation)] - transformations 
+= DEFAULT_TRANSFORMATION_MANAGER.get_reversible_transformations(self.ipu_config.optimization_level) + transformations += DEFAULT_TRANSFORMATION_MANAGER.get_reversible_transformations( + self.ipu_config.optimization_level + ) composition = compose(*transformations) self = composition(self, reverse=True) return self diff --git a/optimum/graphcore/models/vit/modeling_vit.py b/optimum/graphcore/models/vit/modeling_vit.py index aff9c1c8d..47047e78a 100644 --- a/optimum/graphcore/models/vit/modeling_vit.py +++ b/optimum/graphcore/models/vit/modeling_vit.py @@ -16,11 +16,11 @@ from ....fx.optimization import compose from ....utils import logging from ...fx import ( + DEFAULT_TRANSFORMATION_MANAGER, AddPoptorchBlock, AddPoptorchBlocksInSeries, RecomputationCheckpoint, symbolic_trace_pipelined_model, - DEFAULT_TRANSFORMATION_MANAGER, ) from ...modeling_utils import PipelineMixin, get_layer_ipu, register @@ -52,9 +52,13 @@ def parallelize(self): super().parallelize() traced = symbolic_trace_pipelined_model(self) transformations = self.get_transformations() - transformations += DEFAULT_TRANSFORMATION_MANAGER.get_reversible_transformations(self.ipu_config.optimization_level) + transformations += DEFAULT_TRANSFORMATION_MANAGER.get_reversible_transformations( + self.ipu_config.optimization_level + ) composition = compose(*transformations) - non_reversible_composition = DEFAULT_TRANSFORMATION_MANAGER.compose_non_reversible_transformations(self.ipu_config.optimization_level) + non_reversible_composition = DEFAULT_TRANSFORMATION_MANAGER.compose_non_reversible_transformations( + self.ipu_config.optimization_level + ) traced = composition(traced) traced = non_reversible_composition(traced) return traced @@ -62,7 +66,9 @@ def parallelize(self): def deparallelize(self): super().deparallelize() transformations = self.get_transformations() - transformations += DEFAULT_TRANSFORMATION_MANAGER.get_reversible_transformations(self.ipu_config.optimization_level) + transformations += DEFAULT_TRANSFORMATION_MANAGER.get_reversible_transformations( + self.ipu_config.optimization_level + ) composition = compose(*transformations) self = composition(self, reverse=True) return self diff --git a/optimum/graphcore/models/wav2vec2/modeling_wav2vec2.py b/optimum/graphcore/models/wav2vec2/modeling_wav2vec2.py index 344225bc4..a62192d3f 100755 --- a/optimum/graphcore/models/wav2vec2/modeling_wav2vec2.py +++ b/optimum/graphcore/models/wav2vec2/modeling_wav2vec2.py @@ -30,17 +30,15 @@ Wav2Vec2GumbelVectorQuantizer, ) -from ....fx.optimization import ChangeTrueDivToMulByInverse, MergeLinears, compose +from ....fx.optimization import compose from ....utils import logging -from ...fx.transformations import ( +from ...fx import ( + DEFAULT_TRANSFORMATION_MANAGER, AddPoptorchBlock, AddPoptorchBlocksInSeries, - ClipValues, - ClipValuesSymmetric, RecomputationCheckpoint, - TupleOutput, + symbolic_trace_pipelined_model, ) -from ...fx.utils import symbolic_trace_pipelined_model from ...modeling_utils import PipelineMixin, get_layer_ipu, register from .ipu_gumbel_vector_quantizer import IPUWav2Vec2GumbelVectorQuantizer from .ipu_layer_drop import IPUWav2Vec2Adapter, IPUWav2Vec2Encoder, IPUWav2Vec2EncoderStableLayerNorm @@ -48,17 +46,6 @@ logger = logging.get_logger(__name__) -_OPTIMIZATION_TRANSFORMATIONS = [ - ChangeTrueDivToMulByInverse(), - MergeLinears(), - # FuseBiasInLinear(), -] - -_NON_REVERSIBLE_TRANSFORMATIONS = [ - ClipValuesSymmetric(1e4, exclude_targets=["view"]), - ClipValues(1e-4, float("inf"), 
include_targets=[torch.nn.LayerNorm]), -] - class IPUWav2Vec2Model(Wav2Vec2Model): def _get_feature_vector_attention_mask( @@ -132,9 +119,13 @@ def parallelize(self): super().parallelize() traced = symbolic_trace_pipelined_model(self) transformations = self.get_transformations() - transformations += _OPTIMIZATION_TRANSFORMATIONS + transformations += DEFAULT_TRANSFORMATION_MANAGER.get_reversible_transformations( + self.ipu_config.optimization_level + ) composition = compose(*transformations) - non_reversible_composition = compose(*_NON_REVERSIBLE_TRANSFORMATIONS) + non_reversible_composition = DEFAULT_TRANSFORMATION_MANAGER.compose_non_reversible_transformations( + self.ipu_config.optimization_level + ) traced = composition(traced) traced = non_reversible_composition(traced) return traced @@ -142,7 +133,9 @@ def parallelize(self): def deparallelize(self): super().deparallelize() transformations = self.get_transformations() - transformations += _OPTIMIZATION_TRANSFORMATIONS + transformations += DEFAULT_TRANSFORMATION_MANAGER.get_reversible_transformations( + self.ipu_config.optimization_level + ) composition = compose(*transformations) self = composition(self, reverse=True) return self From ba03938cfcc3065d90eed543827a18908944aa8f Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Thu, 20 Oct 2022 18:43:50 +0200 Subject: [PATCH 24/33] Change deberta --- optimum/graphcore/models/bart/modeling_bart.py | 2 +- optimum/graphcore/models/deberta/modeling_deberta.py | 12 +++++------- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/optimum/graphcore/models/bart/modeling_bart.py b/optimum/graphcore/models/bart/modeling_bart.py index 7b9148090..6cb6a3943 100644 --- a/optimum/graphcore/models/bart/modeling_bart.py +++ b/optimum/graphcore/models/bart/modeling_bart.py @@ -44,7 +44,7 @@ FLOAT16_LIMIT = 1e4 TRANSFORMATION_MANAGER = DEFAULT_TRANSFORMATION_MANAGER.without(ClipValuesSymmetric(1e4, exclude_targets=("view",))) -TRANSFORMATION_MANAGER.register(1, ClipValuesSymmetric(10000, exclude_targets=("view",))) +TRANSFORMATION_MANAGER.register((1, ClipValuesSymmetric(10000, exclude_targets=("view",)))) def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0): diff --git a/optimum/graphcore/models/deberta/modeling_deberta.py b/optimum/graphcore/models/deberta/modeling_deberta.py index d6c9e5df1..a57025197 100644 --- a/optimum/graphcore/models/deberta/modeling_deberta.py +++ b/optimum/graphcore/models/deberta/modeling_deberta.py @@ -50,6 +50,8 @@ from ...modeling_utils import OnehotGather, PipelineMixin, get_layer_ipu, register +TRANSFORMATION_MANAGER = DEFAULT_TRANSFORMATION_MANAGER.without(MergeLinears()) + logger = logging.get_logger(__name__) @@ -335,11 +337,9 @@ def parallelize(self): traced = symbolic_trace_pipelined_model(self) torch.nn.functional.one_hot = orig transformations = self.get_transformations() - transformations += DEFAULT_TRANSFORMATION_MANAGER.get_reversible_transformations( - self.ipu_config.optimization_level - ) + transformations += TRANSFORMATION_MANAGER.get_reversible_transformations(self.ipu_config.optimization_level) composition = compose(*transformations) - non_reversible_composition = DEFAULT_TRANSFORMATION_MANAGER.compose_non_reversible_transformations( + non_reversible_composition = TRANSFORMATION_MANAGER.compose_non_reversible_transformations( self.ipu_config.optimization_level ) traced = composition(traced) @@ -355,9 +355,7 @@ def deparallelize(self): super().deparallelize() self.change_modules_for_ipu(True) 
transformations = self.get_transformations() - transformations += DEFAULT_TRANSFORMATION_MANAGER.get_reversible_transformations( - self.ipu_config.optimization_level - ) + transformations += TRANSFORMATION_MANAGER.get_reversible_transformations(self.ipu_config.optimization_level) composition = compose(*transformations) self = composition(self, reverse=True) return self From 89f04c8d9155d33b57585575fa4a3732f2e027a7 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Fri, 21 Oct 2022 12:42:46 +0200 Subject: [PATCH 25/33] Add missing files --- optimum/graphcore/fx/__init__.py | 31 +++++ .../graphcore/fx/transformation_manager.py | 118 ++++++++++++++++++ .../graphcore/models/bart/modeling_bart.py | 4 +- .../models/wav2vec2/modeling_wav2vec2.py | 4 +- tests/test_examples.py | 6 +- 5 files changed, 158 insertions(+), 5 deletions(-) create mode 100644 optimum/graphcore/fx/__init__.py create mode 100644 optimum/graphcore/fx/transformation_manager.py diff --git a/optimum/graphcore/fx/__init__.py b/optimum/graphcore/fx/__init__.py new file mode 100644 index 000000000..6db684224 --- /dev/null +++ b/optimum/graphcore/fx/__init__.py @@ -0,0 +1,31 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .transformation_manager import DEFAULT_TRANSFORMATION_MANAGER, TransformationManager # noqa +from .transformations import ( # noqa + AddPoptorchBlock, + AddPoptorchBlocksInSeries, + AutoCast, + ClipValues, + ClipValuesSymmetric, + LinearToSerializedLinear, + OutlineAttribute, + RecomputationCheckpoint, + ShareEmbeddingComputation, + TieWeights, + TupleOutput, + VocabEmbeddingToSerializedEmbedding, +) +from .utils import symbolic_trace_pipelined_model # noqa diff --git a/optimum/graphcore/fx/transformation_manager.py b/optimum/graphcore/fx/transformation_manager.py new file mode 100644 index 000000000..c80c21ed6 --- /dev/null +++ b/optimum/graphcore/fx/transformation_manager.py @@ -0,0 +1,118 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Defines the class that manages which transformations to apply to which model, according to some optimization level.""" +import copy +import functools +from typing import Iterator, List, Tuple, Union + +import torch + +from ...fx.optimization import ( + ChangeTrueDivToMulByInverse, + MergeLinears, + ReversibleTransformation, + Transformation, + compose, +) +from .transformations import ClipValues, ClipValuesSymmetric + + +class TransformationManager: + def __init__(self, *transformations: Tuple[int, "Transformation"]): + self._signatures = { + 0: set(), + 1: set(), + 2: set(), + } + self._transformations = { + 0: [], + 1: [], + 2: [], + } + self.register(*transformations) + + def without(self, *args: Transformation) -> "TransformationManager": + clone = copy.deepcopy(self) + clone.unregister(*args) + return clone + + def register(self, *transformations: Tuple[int, Transformation]): + for (opt_level, t) in transformations: + for k, signatures in self._signatures.items(): + if t.signature in signatures: + raise RuntimeError( + f"The transformation {t} has already been registered at optimization level = {k}." + ) + self._signatures[opt_level].add(t.signature) + self._transformations[opt_level].append(t) + + def unregister(self, *transformations: Transformation): + for transformation_to_unregister in transformations: + level = None + sig = transformation_to_unregister.signature + for opt_level, signatures in self._signatures.items(): + if sig in signatures: + level = opt_level + signatures.remove(sig) + if level is not None: + idx_to_pop = None + for idx, t in enumerate(self._transformations[level]): + if t.signature == sig: + idx_to_pop = idx + break + self._transformations[level].pop(idx_to_pop) + + def _check_optimization_level(self, optimization_level): + if optimization_level not in [0, 1, 2]: + raise ValueError(f"The optimization level must be either 0, 1 or 2, but {optimization_level} was given.") + + def _get_transformations( + self, optimization_level: int, as_list: bool = False + ) -> Union[Iterator[Transformation], List[Transformation]]: + self._check_optimization_level(optimization_level) + # iterator = itertools.chain(self._transformations[i] for i in range(optimization_level + 1) if self._transformations[i]) + iterator = functools.reduce( + lambda x, y: x + y, (self._transformations[i] for i in range(optimization_level + 1)), [] + ) + return iterator if as_list is False else list(iterator) + + def get_transformations(self, optimization_level: int) -> List[Transformation]: + return self._get_transformations(optimization_level, as_list=True) + + def get_non_reversible_transformations(self, optimization_level: int) -> List[Transformation]: + return [ + t for t in self._get_transformations(optimization_level) if not isinstance(t, ReversibleTransformation) + ] + + def get_reversible_transformations(self, optimization_level: int) -> List[ReversibleTransformation]: + return [t for t in self._get_transformations(optimization_level) if isinstance(t, ReversibleTransformation)] + + def compose_transformations(self, optimization_level: int) -> Transformation: + return compose(self._get_transformations(optimization_level)) + + def compose_non_reversible_transformations(self, optimization_level: int) -> Transformation: + return compose(*self.get_non_reversible_transformations(optimization_level)) + + def compose_reversible_transformations(self, optimization_level: int) -> ReversibleTransformation: + return compose(*self.get_reversible_transformations(optimization_level)) + + 
+DEFAULT_TRANSFORMATION_MANAGER = TransformationManager( + (1, ChangeTrueDivToMulByInverse()), + (1, MergeLinears()), + # (1, FuseBiasInLinear()), + (1, ClipValuesSymmetric(1e4, exclude_targets=("view",))), + (1, ClipValues(1e-4, float("inf"), include_targets=(torch.nn.LayerNorm,))), +) diff --git a/optimum/graphcore/models/bart/modeling_bart.py b/optimum/graphcore/models/bart/modeling_bart.py index 6cb6a3943..e3967efa6 100644 --- a/optimum/graphcore/models/bart/modeling_bart.py +++ b/optimum/graphcore/models/bart/modeling_bart.py @@ -243,7 +243,7 @@ def deparallelize(self): @register(BartForConditionalGeneration) class PipelinedBartForConditionalGeneration( - GenerationMethodsMixin, BartForConditionalGeneration, PipelineMixin, IPUGenerationMixin + GenerationMethodsMixin, BartForConditionalGeneration, BartPipelineMixin, IPUGenerationMixin ): def get_transformations(self): log_insertions = self.ipu_config.log_insertions @@ -343,7 +343,7 @@ def forward( @register(BartForSequenceClassification) -class PipelinedBartForSequenceClassification(BartForSequenceClassification, PipelineMixin): +class PipelinedBartForSequenceClassification(BartForSequenceClassification, BartPipelineMixin): def get_transformations(self): log_insertions = self.ipu_config.log_insertions layer_ipu = get_layer_ipu(self.ipu_config.layers_per_ipu) diff --git a/optimum/graphcore/models/wav2vec2/modeling_wav2vec2.py b/optimum/graphcore/models/wav2vec2/modeling_wav2vec2.py index a62192d3f..fed853cc2 100755 --- a/optimum/graphcore/models/wav2vec2/modeling_wav2vec2.py +++ b/optimum/graphcore/models/wav2vec2/modeling_wav2vec2.py @@ -466,7 +466,6 @@ def _add_begin_block(self, module, name, ipu_id): def get_transformations(self): log_insertions = self.ipu_config.log_insertions layer_ipu = get_layer_ipu(self.ipu_config.layers_per_ipu) - print(layer_ipu) transformations = super().get_transformations() start_idx = self.config.num_feat_extract_layers + 2 transformations.append( @@ -492,7 +491,8 @@ def parallelize(self): self.change_wav2vec2_encoder_class(False) self.change_wav2vec2_adapter_class(False) self.change_conv_eps(False) - return super().parallelize() + traced = super().parallelize() + return traced def deparallelize(self): """ diff --git a/tests/test_examples.py b/tests/test_examples.py index 5427b4413..3f7ac1098 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -263,6 +263,10 @@ def _create_command_line( "device_iterations=1", f"inference_device_iterations={inference_device_iterations}", f"gradient_accumulation_steps={gradient_accumulation_steps}", + # TODO: only testing examples without any optimization, since it can make training harder (from what was + # observed with previous testing). This will need investigation, so only validating the "vanilla" + # pipelined models for now. 
+ "optimization_level=0", ] ) @@ -563,7 +567,7 @@ class SpeechRecognitionExampleTester( TASK_NAME = "common_voice" DATASET_CONFIG_NAME = "tr" TRAIN_BATCH_SIZE = 1 - GRADIENT_ACCUMULATION_STEPS = 8 + GRADIENT_ACCUMULATION_STEPS = 16 EVAL_BATCH_SIZE = 1 NUM_EPOCHS = 15 # Here we are evaluating against the loss because it can take a long time to have wer < 1.0 From 523f2fccd9b633173b559dae34abc3423bd0c3f1 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Fri, 21 Oct 2022 15:46:33 +0200 Subject: [PATCH 26/33] Format --- .../graphcore/fx/transformation_manager.py | 19 ++++++++--- optimum/graphcore/fx/transformations.py | 2 ++ optimum/graphcore/fx/utils.py | 12 +++++++ optimum/graphcore/ipu_configuration.py | 24 +++++++------- optimum/graphcore/modeling_utils.py | 32 +++++++++++-------- .../graphcore/models/bart/modeling_bart.py | 24 +++++++------- .../graphcore/models/bert/modeling_bert.py | 3 ++ .../models/convnext/modeling_convnext.py | 3 ++ .../models/deberta/modeling_deberta.py | 3 +- .../models/distilbert/modeling_distilbert.py | 2 ++ .../graphcore/models/gpt2/modeling_gpt2.py | 2 ++ .../graphcore/models/hubert/ipu_layer_drop.py | 2 -- .../models/hubert/modeling_hubert.py | 9 +++--- .../models/lxmert/modeling_lxmert.py | 3 ++ .../models/roberta/modeling_roberta.py | 2 ++ optimum/graphcore/models/t5/modeling_t5.py | 22 +++++++------ optimum/graphcore/models/vit/modeling_vit.py | 3 ++ .../models/wav2vec2/modeling_wav2vec2.py | 2 ++ optimum/graphcore/trainer.py | 24 +++++++------- optimum/graphcore/trainer_seq2seq.py | 5 +-- tests/test_ipu_configuration.py | 2 +- 21 files changed, 128 insertions(+), 72 deletions(-) diff --git a/optimum/graphcore/fx/transformation_manager.py b/optimum/graphcore/fx/transformation_manager.py index c80c21ed6..65b3e80fb 100644 --- a/optimum/graphcore/fx/transformation_manager.py +++ b/optimum/graphcore/fx/transformation_manager.py @@ -12,7 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-"""Defines the class that manages which transformations to apply to which model, according to some optimization level.""" +"""Defines the class that manages which transformations to apply according to some optimization level.""" + import copy import functools from typing import Iterator, List, Tuple, Union @@ -82,7 +83,6 @@ def _get_transformations( self, optimization_level: int, as_list: bool = False ) -> Union[Iterator[Transformation], List[Transformation]]: self._check_optimization_level(optimization_level) - # iterator = itertools.chain(self._transformations[i] for i in range(optimization_level + 1) if self._transformations[i]) iterator = functools.reduce( lambda x, y: x + y, (self._transformations[i] for i in range(optimization_level + 1)), [] ) @@ -99,14 +99,23 @@ def get_non_reversible_transformations(self, optimization_level: int) -> List[Tr def get_reversible_transformations(self, optimization_level: int) -> List[ReversibleTransformation]: return [t for t in self._get_transformations(optimization_level) if isinstance(t, ReversibleTransformation)] + def _compose_transformations( + self, optimization_level: int, transformations: List[Transformation] + ) -> Transformation: + return compose(*transformations) if transformations else lambda x: x + def compose_transformations(self, optimization_level: int) -> Transformation: - return compose(self._get_transformations(optimization_level)) + return self._compose_transformations(optimization_level, self.get_transformations(optimization_level)) def compose_non_reversible_transformations(self, optimization_level: int) -> Transformation: - return compose(*self.get_non_reversible_transformations(optimization_level)) + return self._compose_transformations( + optimization_level, self.get_non_reversible_transformations(optimization_level) + ) def compose_reversible_transformations(self, optimization_level: int) -> ReversibleTransformation: - return compose(*self.get_reversible_transformations(optimization_level)) + return self._compose_transformations( + optimization_level, self.get_reversible_transformations(optimization_level) + ) DEFAULT_TRANSFORMATION_MANAGER = TransformationManager( diff --git a/optimum/graphcore/fx/transformations.py b/optimum/graphcore/fx/transformations.py index 9a476e643..d4a261940 100644 --- a/optimum/graphcore/fx/transformations.py +++ b/optimum/graphcore/fx/transformations.py @@ -12,6 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +"""All the parallelization-related transformations.""" + import collections import operator import re diff --git a/optimum/graphcore/fx/utils.py b/optimum/graphcore/fx/utils.py index 43f61b067..b5790af7c 100644 --- a/optimum/graphcore/fx/utils.py +++ b/optimum/graphcore/fx/utils.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+"""Utilties related to FX.""" import inspect import math from typing import TYPE_CHECKING, Callable, Dict, List, Optional @@ -188,12 +189,23 @@ def symbolic_trace_with_pipelined_tracer( def cast_traced_model_to_proper_class(model: torch.nn.Module, traced: torch.fx.GraphModule): + """Casts the traced `torch.fx.GraphModule` to the original class of the traced model.""" type_ = type(f"Traced{model.__class__.__name__}", (torch.fx.GraphModule, model.__class__), {}) traced.__class__ = type_ traced.recompile() def symbolic_trace_pipelined_model(pipelined_model: PipelineMixin) -> PipelineMixin: + """ + Traces a pipelined model and casts the traced model to the original class of the model. + + Args: + pipelined_model ([`~PipelineMixin`]): + The pipelined model. + + Returns: + [`~PipelineMixin`]: The traced model. + """ if isinstance(pipelined_model, torch.fx.GraphModule): return pipelined_model diff --git a/optimum/graphcore/ipu_configuration.py b/optimum/graphcore/ipu_configuration.py index 9fd5440d6..64107dc76 100644 --- a/optimum/graphcore/ipu_configuration.py +++ b/optimum/graphcore/ipu_configuration.py @@ -1,17 +1,18 @@ # coding=utf-8 -# Copyright 2021 The HuggingFace Team. All rights reserved. +# Copyright 2021 The HuggingFace Team. All rights reserved. # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Defines the class handling the configuration of the IPUs.""" import copy import json @@ -49,6 +50,8 @@ class IPUConfig(BaseConfig): **Note: This is an experimental feature and may not behave as expected.** executable_cache_dir (`str`, *optional*, defaults to `""`): Enables caching the compile executables to a directory. + log_insertions (`bool`, *optional*, defaults to `False`): + Whether the block insertion should be logged during model parallelization. optimization_level (`int`, *optional*, defaults to 1): The optimization level to apply to the model before compilation. Three values are allowed: - 0: No optimization is performed on the graph. 
@@ -165,7 +168,6 @@ def __init__(self, **kwargs): self.execute_encoder_on_cpu_for_generation = kwargs.pop("execute_encoder_on_cpu_for_generation", False) self.log_insertions = kwargs.pop("log_insertions", False) - self.optimization_level = kwargs.pop("optimization_level", 1) def _prepare_config_attribute_for_pod_type( diff --git a/optimum/graphcore/modeling_utils.py b/optimum/graphcore/modeling_utils.py index 23934e2b8..a69dedc69 100644 --- a/optimum/graphcore/modeling_utils.py +++ b/optimum/graphcore/modeling_utils.py @@ -1,20 +1,22 @@ -# Copyright 2021 The HuggingFace Team. All rights reserved. +# coding=utf-8 +# Copyright 2021 The HuggingFace Team. All rights reserved. # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Utilty modules, functions and classes for pipelined models.""" import copy import inspect -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple import torch import torch.nn.functional as F @@ -22,12 +24,14 @@ import poptorch from optimum.utils import logging -from transformers import AutoConfig, PreTrainedModel -from transformers.modeling_outputs import ModelOutput from .ipu_configuration import IPUConfig +if TYPE_CHECKING: + from transformers import PreTrainedModel + + logger = logging.get_logger(__name__) _PRETRAINED_TO_PIPELINED_REGISTRY = {} @@ -75,7 +79,7 @@ def to_pipelined(model: nn.Module, ipu_config: IPUConfig, force: bool = False): class PipelineMixin: @classmethod - def from_transformers(cls, model: PreTrainedModel, ipu_config: IPUConfig): + def from_transformers(cls, model: "PreTrainedModel", ipu_config: IPUConfig): """ Creates a pipeline model from a [`~transformers.PreTrainedModel`]. diff --git a/optimum/graphcore/models/bart/modeling_bart.py b/optimum/graphcore/models/bart/modeling_bart.py index e3967efa6..bf559beac 100644 --- a/optimum/graphcore/models/bart/modeling_bart.py +++ b/optimum/graphcore/models/bart/modeling_bart.py @@ -1,16 +1,19 @@ -# Copyright 2021 The HuggingFace Team. All rights reserved. +# coding=utf-8 +# Copyright 2021 The HuggingFace Team. All rights reserved. # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""BART model.""" + from typing import List, Optional, Tuple, Union import torch @@ -73,7 +76,6 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] inverted_mask = 1 - expanded_mask inverted_mask = -float("inf") * inverted_mask - # inverted_mask = * inverted_mask return inverted_mask diff --git a/optimum/graphcore/models/bert/modeling_bert.py b/optimum/graphcore/models/bert/modeling_bert.py index b29806306..7a337755f 100644 --- a/optimum/graphcore/models/bert/modeling_bert.py +++ b/optimum/graphcore/models/bert/modeling_bert.py @@ -1,3 +1,4 @@ +# coding=utf-8 # Copyright (c) 2021 Graphcore Ltd. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -11,6 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +"""BERT model.""" + from typing import Optional, Tuple, Union import torch diff --git a/optimum/graphcore/models/convnext/modeling_convnext.py b/optimum/graphcore/models/convnext/modeling_convnext.py index 19be99e97..177ceda25 100644 --- a/optimum/graphcore/models/convnext/modeling_convnext.py +++ b/optimum/graphcore/models/convnext/modeling_convnext.py @@ -1,3 +1,4 @@ +# coding=utf-8 # Copyright (c) 2021 Graphcore Ltd. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -11,6 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +"""ConvNeXt model.""" + import torch from torch import nn diff --git a/optimum/graphcore/models/deberta/modeling_deberta.py b/optimum/graphcore/models/deberta/modeling_deberta.py index a57025197..91afc6c6e 100644 --- a/optimum/graphcore/models/deberta/modeling_deberta.py +++ b/optimum/graphcore/models/deberta/modeling_deberta.py @@ -1,3 +1,4 @@ +# coding=utf-8 # Copyright (c) 2021 Graphcore Ltd. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -11,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+"""DeBERTa model.""" import math import operator @@ -41,7 +43,6 @@ DEFAULT_TRANSFORMATION_MANAGER, AddPoptorchBlock, AddPoptorchBlocksInSeries, - AutoCast, OutlineAttribute, RecomputationCheckpoint, VocabEmbeddingToSerializedEmbedding, diff --git a/optimum/graphcore/models/distilbert/modeling_distilbert.py b/optimum/graphcore/models/distilbert/modeling_distilbert.py index 03ddb53e3..ac99179ce 100644 --- a/optimum/graphcore/models/distilbert/modeling_distilbert.py +++ b/optimum/graphcore/models/distilbert/modeling_distilbert.py @@ -1,3 +1,4 @@ +# coding=utf-8 # Copyright (c) 2021 Graphcore Ltd. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -11,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +"""DistilBERT model.""" import math from typing import Optional, Tuple, Union diff --git a/optimum/graphcore/models/gpt2/modeling_gpt2.py b/optimum/graphcore/models/gpt2/modeling_gpt2.py index 4fea246bf..b4401c5db 100644 --- a/optimum/graphcore/models/gpt2/modeling_gpt2.py +++ b/optimum/graphcore/models/gpt2/modeling_gpt2.py @@ -1,3 +1,4 @@ +# coding=utf-8 # Copyright (c) 2021 Graphcore Ltd. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -11,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +"""GPT-2 model.""" import math from typing import Optional, Tuple, Union diff --git a/optimum/graphcore/models/hubert/ipu_layer_drop.py b/optimum/graphcore/models/hubert/ipu_layer_drop.py index 494abd5bd..200644a5c 100644 --- a/optimum/graphcore/models/hubert/ipu_layer_drop.py +++ b/optimum/graphcore/models/hubert/ipu_layer_drop.py @@ -12,14 +12,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """ These are the same blocks as in the original implementation in transformers, but with a traceable implementation of LayerDrop. """ import torch -from torch.nn import functional as F from transformers.modeling_outputs import BaseModelOutput from transformers.models.hubert.modeling_hubert import HubertEncoder, HubertEncoderStableLayerNorm diff --git a/optimum/graphcore/models/hubert/modeling_hubert.py b/optimum/graphcore/models/hubert/modeling_hubert.py index a36a31893..a7125dc81 100644 --- a/optimum/graphcore/models/hubert/modeling_hubert.py +++ b/optimum/graphcore/models/hubert/modeling_hubert.py @@ -1,3 +1,4 @@ +# coding=utf-8 # Copyright (c) 2021 Graphcore Ltd. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -11,7 +12,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+"""HuBERT model.""" + from transformers import HubertForSequenceClassification +from transformers.models.hubert.modeling_hubert import HubertEncoder, HubertEncoderStableLayerNorm from ....fx.optimization import MergeLinears, compose from ....utils import logging @@ -23,6 +27,7 @@ symbolic_trace_pipelined_model, ) from ...modeling_utils import PipelineMixin, get_layer_ipu, register +from .ipu_layer_drop import IPUHubertEncoder, IPUHubertEncoderStableLayerNorm logger = logging.get_logger(__name__) @@ -37,10 +42,6 @@ def change_hubert_encoder_class(self, restore: bool): Args: restore: whether to restore the encoder to its original version or not. """ - from transformers.models.hubert.modeling_hubert import HubertEncoder, HubertEncoderStableLayerNorm - - from .ipu_layer_drop import IPUHubertEncoder, IPUHubertEncoderStableLayerNorm - if self.config.do_stable_layer_norm: new_cls = HubertEncoderStableLayerNorm if restore else IPUHubertEncoderStableLayerNorm else: diff --git a/optimum/graphcore/models/lxmert/modeling_lxmert.py b/optimum/graphcore/models/lxmert/modeling_lxmert.py index 9d476198a..7efb6826a 100644 --- a/optimum/graphcore/models/lxmert/modeling_lxmert.py +++ b/optimum/graphcore/models/lxmert/modeling_lxmert.py @@ -1,3 +1,4 @@ +# coding=utf-8 # Copyright (c) 2021 Graphcore Ltd. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -11,6 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +"""LXMERT model.""" + from typing import Optional, Tuple, Union import torch diff --git a/optimum/graphcore/models/roberta/modeling_roberta.py b/optimum/graphcore/models/roberta/modeling_roberta.py index 562b49a28..cae51e470 100644 --- a/optimum/graphcore/models/roberta/modeling_roberta.py +++ b/optimum/graphcore/models/roberta/modeling_roberta.py @@ -1,3 +1,4 @@ +# coding=utf-8 # Copyright (c) 2021 Graphcore Ltd. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -11,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +"""RoBERTa model.""" from typing import Optional, Tuple, Union diff --git a/optimum/graphcore/models/t5/modeling_t5.py b/optimum/graphcore/models/t5/modeling_t5.py index 00653ebc1..07074a29f 100644 --- a/optimum/graphcore/models/t5/modeling_t5.py +++ b/optimum/graphcore/models/t5/modeling_t5.py @@ -1,16 +1,18 @@ -# Copyright 2022 The HuggingFace Team. All rights reserved. +# coding=utf-8 +# Copyright 2022 The HuggingFace Team. All rights reserved. # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
+# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""T5 model.""" import warnings from typing import Optional, Tuple, Union diff --git a/optimum/graphcore/models/vit/modeling_vit.py b/optimum/graphcore/models/vit/modeling_vit.py index 47047e78a..13bdef5d4 100644 --- a/optimum/graphcore/models/vit/modeling_vit.py +++ b/optimum/graphcore/models/vit/modeling_vit.py @@ -1,3 +1,4 @@ +# coding=utf-8 # Copyright (c) 2021 Graphcore Ltd. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -11,6 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +"""ViT model.""" + import transformers from ....fx.optimization import compose diff --git a/optimum/graphcore/models/wav2vec2/modeling_wav2vec2.py b/optimum/graphcore/models/wav2vec2/modeling_wav2vec2.py index fed853cc2..2a58fd92e 100755 --- a/optimum/graphcore/models/wav2vec2/modeling_wav2vec2.py +++ b/optimum/graphcore/models/wav2vec2/modeling_wav2vec2.py @@ -1,3 +1,4 @@ +# coding=utf-8 # Copyright (c) 2022 Graphcore Ltd. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -11,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +"""Wav2Vec2 model.""" from typing import Optional, Tuple, Union diff --git a/optimum/graphcore/trainer.py b/optimum/graphcore/trainer.py index 3deffb594..8d050292a 100644 --- a/optimum/graphcore/trainer.py +++ b/optimum/graphcore/trainer.py @@ -1,16 +1,18 @@ -# copyright 2021 the huggingface team. all rights reserved. +# coding=utf-8 +# Copyright 2019 the huggingface team. all rights reserved. # -# licensed under the apache license, version 2.0 (the "license"); -# you may not use this file except in compliance with the license. -# you may obtain a copy of the license at +# licensed under the apache license, version 2.0 (the "license"); +# you may not use this file except in compliance with the license. +# you may obtain a copy of the license at # -# http://www.apache.org/licenses/license-2.0 +# http://www.apache.org/licenses/license-2.0 # -# unless required by applicable law or agreed to in writing, software -# distributed under the license is distributed on an "as is" basis, -# without warranties or conditions of any kind, either express or implied. -# see the license for the specific language governing permissions and -# limitations under the license. +# unless required by applicable law or agreed to in writing, software +# distributed under the license is distributed on an "as is" basis, +# without warranties or conditions of any kind, either express or implied. +# see the license for the specific language governing permissions and +# limitations under the license. 
+"""The IPUTrainer class, handling everything to perform training and evaluation of models on IPUs.""" import collections import copy @@ -39,7 +41,7 @@ import torch from packaging import version from torch import nn, optim -from torch.utils.data import Dataset, RandomSampler, SequentialSampler, SubsetRandomSampler +from torch.utils.data import Dataset, RandomSampler, SequentialSampler from torch.utils.data.distributed import DistributedSampler import poptorch diff --git a/optimum/graphcore/trainer_seq2seq.py b/optimum/graphcore/trainer_seq2seq.py index daf2dca2b..cf9000fd1 100644 --- a/optimum/graphcore/trainer_seq2seq.py +++ b/optimum/graphcore/trainer_seq2seq.py @@ -1,3 +1,4 @@ +# coding=utf-8 # Copyright 2022 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -11,9 +12,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +"""IPUTrainer that can handle seq2seq models.""" -import inspect -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import torch from torch import nn diff --git a/tests/test_ipu_configuration.py b/tests/test_ipu_configuration.py index 55b78c162..57743f8ce 100644 --- a/tests/test_ipu_configuration.py +++ b/tests/test_ipu_configuration.py @@ -1,5 +1,5 @@ # coding=utf-8 -# Copyright 2021 HuggingFace Inc. +# Copyright 2021 HuggingFace Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. From 84d55bbd5b7548901274f6b54db9c05ad97079b5 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Fri, 21 Oct 2022 17:27:24 +0200 Subject: [PATCH 27/33] some stuff --- optimum/graphcore/fx/transformation_manager.py | 5 +++-- optimum/graphcore/ipu_configuration.py | 3 ++- optimum/graphcore/models/bart/modeling_bart.py | 2 +- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/optimum/graphcore/fx/transformation_manager.py b/optimum/graphcore/fx/transformation_manager.py index 65b3e80fb..0812989c0 100644 --- a/optimum/graphcore/fx/transformation_manager.py +++ b/optimum/graphcore/fx/transformation_manager.py @@ -122,6 +122,7 @@ def compose_reversible_transformations(self, optimization_level: int) -> Reversi (1, ChangeTrueDivToMulByInverse()), (1, MergeLinears()), # (1, FuseBiasInLinear()), - (1, ClipValuesSymmetric(1e4, exclude_targets=("view",))), - (1, ClipValues(1e-4, float("inf"), include_targets=(torch.nn.LayerNorm,))), + # Those change the computation, but are actually needed for fp16 stability. + (0, ClipValuesSymmetric(1e4, exclude_targets=("view",))), + (0, ClipValues(1e-4, float("inf"), include_targets=(torch.nn.LayerNorm,))), ) diff --git a/optimum/graphcore/ipu_configuration.py b/optimum/graphcore/ipu_configuration.py index 64107dc76..287504e22 100644 --- a/optimum/graphcore/ipu_configuration.py +++ b/optimum/graphcore/ipu_configuration.py @@ -168,7 +168,8 @@ def __init__(self, **kwargs): self.execute_encoder_on_cpu_for_generation = kwargs.pop("execute_encoder_on_cpu_for_generation", False) self.log_insertions = kwargs.pop("log_insertions", False) - self.optimization_level = kwargs.pop("optimization_level", 1) + # TODO: set that to one, once everything is working. 
+ self.optimization_level = kwargs.pop("optimization_level", 0) def _prepare_config_attribute_for_pod_type( self, config_attribute_name: str, config_attribute: Union[Any, Dict[str, Any]], pod_type: Optional[str] diff --git a/optimum/graphcore/models/bart/modeling_bart.py b/optimum/graphcore/models/bart/modeling_bart.py index bf559beac..dfacc46b8 100644 --- a/optimum/graphcore/models/bart/modeling_bart.py +++ b/optimum/graphcore/models/bart/modeling_bart.py @@ -47,7 +47,7 @@ FLOAT16_LIMIT = 1e4 TRANSFORMATION_MANAGER = DEFAULT_TRANSFORMATION_MANAGER.without(ClipValuesSymmetric(1e4, exclude_targets=("view",))) -TRANSFORMATION_MANAGER.register((1, ClipValuesSymmetric(10000, exclude_targets=("view",)))) +TRANSFORMATION_MANAGER.register((0, ClipValuesSymmetric(10000, exclude_targets=("view",)))) def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0): From 8b2e8fb3e9c4aed15788e82bbae9eaf67e6bcf9c Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Fri, 28 Oct 2022 16:10:10 +0200 Subject: [PATCH 28/33] Fix BART --- optimum/graphcore/fx/transformations.py | 40 +++++++++++++++---- .../graphcore/models/bart/modeling_bart.py | 12 +----- 2 files changed, 35 insertions(+), 17 deletions(-) diff --git a/optimum/graphcore/fx/transformations.py b/optimum/graphcore/fx/transformations.py index d4a261940..87868d395 100644 --- a/optimum/graphcore/fx/transformations.py +++ b/optimum/graphcore/fx/transformations.py @@ -376,22 +376,38 @@ class VocabEmbeddingToSerializedEmbedding(ReversibleTransformation): """ def __init__(self, name_regex: Optional[str] = None): - self.name_regex = re.compile(name_regex) if name_regex else None + self.name_regex_for_module = re.compile(name_regex) if name_regex else None + self.name_regex_for_function = re.compile(name_regex.replace(".", "_")) if name_regex else None def transform(self, graph_module: "GraphModule") -> "GraphModule": embedding_nodes = [] for node in graph_module.graph.nodes: - if node.op != "call_module": - continue - match = re.match(self.name_regex, node.target) if self.name_regex is not None else True - if match and isinstance(graph_module.get_submodule(node.target), torch.nn.Embedding): + if node.op == "call_module": + if self.name_regex_for_module is not None and not re.match(self.name_regex_for_module, node.target): + continue + elif not isinstance(graph_module.get_submodule(node.target), torch.nn.Embedding): + continue + embedding_nodes.append(node) + elif node.op == "call_function": + if self.name_regex_for_function is not None and not re.match(self.name_regex_for_function, node.name): + continue + elif node.target is not torch.nn.functional.embedding: + continue embedding_nodes.append(node) # We assume the vocab embedding to be the embedding with the maximum number of embeddings. 
         if not embedding_nodes:
             raise RuntimeError("Could not find any embedding node")
-        embedding_node = max(embedding_nodes, key=lambda node: graph_module.get_submodule(node.target).num_embeddings)
+        def sort_nodes_function(node):
+            if node.op == "call_module":
+                return graph_module.get_submodule(node.target).num_embeddings
+            return node.args[1].shape[1]
+
+        embedding_node = max(embedding_nodes, key=sort_nodes_function)
+        if embedding_node.op == "call_function":
+            raise NotImplementedError("VocabEmbeddingToSerializedEmbedding does not support torch.nn.functional.embedding yet.")
+
         split = embedding_node.target.rsplit(".", maxsplit=1)
         if len(split) == 1:
             split = [None] + split
@@ -504,6 +520,12 @@ def transform(self, graph_module: "GraphModule") -> "GraphModule":


 class ShareEmbeddingComputation(Transformation):
+    def __init__(self, name_regex: Optional[str] = None, allowed_embedding_classes: Union[Tuple[Type], Type] = (torch.nn.Embedding, SerializedEmbedding)):
+        self.name_regex = re.compile(name_regex) if name_regex else None
+        self.allowed_embedding_classes = allowed_embedding_classes
+        if not isinstance(self.allowed_embedding_classes, tuple):
+            self.allowed_embedding_classes = (self.allowed_embedding_classes,)
+
     def _find_nodes_to_move(self, graph_module, embedding_input_node, shared_embedding_node):
         nodes_before_embedding_input_node = set()
         for node in graph_module.graph.nodes:
@@ -535,7 +557,11 @@ def transform(self, graph_module: "GraphModule") -> "GraphModule":
         candidates = collections.defaultdict(list)
         embedding_nodes = collections.defaultdict(list)
         for node in graph_module.graph.nodes:
-            if node.op == "call_module" and isinstance(graph_module.get_submodule(node.target), torch.nn.Embedding):
+            if node.op == "call_module":
+                if self.name_regex is not None and not re.match(self.name_regex, node.target):
+                    continue
+                elif not isinstance(graph_module.get_submodule(node.target), self.allowed_embedding_classes):
+                    continue
                 candidates[node.target].append(node.args[0])
                 embedding_nodes[node.target].append(node)

diff --git a/optimum/graphcore/models/bart/modeling_bart.py b/optimum/graphcore/models/bart/modeling_bart.py
index dfacc46b8..a056034f4 100644
--- a/optimum/graphcore/models/bart/modeling_bart.py
+++ b/optimum/graphcore/models/bart/modeling_bart.py
@@ -252,10 +252,6 @@ def get_transformations(self):
         layer_ipu = get_layer_ipu(self.ipu_config.layers_per_ipu)
         transformations = [
             AddPoptorchBlock("Embedding", 0, "model.shared", log_insertions=log_insertions),
-            # AddPoptorchBlock("Embedding", 0, "model.encoder.embed_positions"),
-            # AddPoptorchBlock("Embedding", 0, "model.encoder.layernorm_embedding"),
-            # AddPoptorchBlock("Embedding", 0, "model.decoder.embed_positions"),
-            # AddPoptorchBlock("Embedding", 0, "model.decoder.layernorm_embedding"),
             AddPoptorchBlocksInSeries(
                 "Encoder",
                 layer_ipu[: self.config.encoder_layers],
@@ -351,10 +347,6 @@ def get_transformations(self):
         layer_ipu = get_layer_ipu(self.ipu_config.layers_per_ipu)
         transformations = [
             AddPoptorchBlock("Embedding", 0, "model.shared", log_insertions=log_insertions),
-            # AddPoptorchBlock("Embedding", 0, "model.encoder.embed_positions"),
-            # AddPoptorchBlock("Embedding", 0, "model.encoder.layernorm_embedding"),
-            # AddPoptorchBlock("Embedding", 0, "model.decoder.embed_positions"),
-            # AddPoptorchBlock("Embedding", 0, "model.decoder.layernorm_embedding"),
             AddPoptorchBlocksInSeries(
                 "Encoder",
                 layer_ipu[: self.config.encoder_layers],
@@ -383,8 +375,8 @@ def get_transformations(self):
         if not isinstance(self, torch.fx.GraphModule):
             if self.ipu_config.embedding_serialization_factor > 1:
-                transformations.append(VocabEmbeddingToSerializedEmbedding())
-                transformations += [ShareEmbeddingComputation()]
+                transformations.append(VocabEmbeddingToSerializedEmbedding("model.shared"))
+                transformations += [ShareEmbeddingComputation("model.shared")]
         return transformations

     def forward(

From c94d930fecf814f6fbd9a5f0709b6fd5e1cfd654 Mon Sep 17 00:00:00 2001
From: Michael Benayoun
Date: Mon, 7 Nov 2022 14:20:05 +0100
Subject: [PATCH 29/33] changes

---
 optimum/graphcore/fx/utils.py                  | 59 ++++++++++++-------
 .../models/deberta/modeling_deberta.py         | 52 ++++++++++++----
 optimum/graphcore/models/gpt2/modeling_gpt2.py |  6 +-
 optimum/graphcore/trainer.py                   |  6 +-
 optimum/graphcore/training_args.py             |  2 -
 5 files changed, 88 insertions(+), 37 deletions(-)

diff --git a/optimum/graphcore/fx/utils.py b/optimum/graphcore/fx/utils.py
index b5790af7c..39f8d96ed 100644
--- a/optimum/graphcore/fx/utils.py
+++ b/optimum/graphcore/fx/utils.py
@@ -25,7 +25,7 @@
     MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES,
     MODEL_FOR_CTC_MAPPING_NAMES,
 )
-from transformers.utils.fx import HFTracer, get_concrete_args
+from transformers.utils.fx import HFAttribute, HFProxy, HFTracer, get_concrete_args

 from ..modeling_utils import PipelineMixin

@@ -34,11 +34,30 @@
     from transformers import PreTrainedModel


+# TODO: keep this until transformers >= 4.23.2
+class GCProxy(HFProxy):
+
+    @property
+    def dtype(self):
+        return self.__getattr__("dtype")
+
+    def __getattr__(self, k):
+        if k == "_metadata":
+            return self.__getattribute__(k)
+        # note: not added to the graph yet, if this is a method call
+        # we peephole optimize to the method invocation
+        hf_attribute = HFAttribute(self, k)
+        if hasattr(self, "_metadata"):
+            hf_attribute.install_metadata(getattr(self._metadata, k))
+        return hf_attribute
+
+
 class PipelinedTracer(HFTracer):
     # TODO: keep this until transformers >= 4.23.2
     _TORCH_METHODS_TO_PATCH = list(HFTracer._TORCH_METHODS_TO_PATCH)
     _TORCH_METHODS_TO_PATCH.append("clamp")
     _TORCH_METHODS_TO_PATCH.append("rand")
+    _TORCH_METHODS_TO_PATCH.append("finfo")
     """
     Tracer that enables tracing and transforming models to run them on IPUs.
     Compared to the HFTracer, this one adds the following features:
@@ -79,8 +98,9 @@ def proxy(self, node):
         # it is easier to use this one, and equivalent.
         node.parent_module_qualified_name = self.current_module_qualified_name[-1]
         node.parent_module_type = self.current_module_type[-1]
-        proxy = super().proxy(node)
-        return proxy
+        return GCProxy(node, self)
+        # return gc_proxy
+        return super().proxy(node)

     def call_module(self, m, forward, args, kwargs):
         # Could be done in a "cleaner" fashion by inlining the content of Tracer.call_module.
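For background on the GCProxy and "finfo" patching above: during symbolic tracing there are no real tensors, so attribute accesses such as dtype have to be answered from metadata recorded on the graph rather than from data. The snippet below is not part of the patch; it is a minimal sketch using stock torch.fx (ShapeProp) showing how shape/dtype metadata can be attached to every node after tracing, which is the same kind of information the proxy above serves lazily during tracing. TinyModel is a made-up example module.

import torch
from torch.fx import symbolic_trace
from torch.fx.passes.shape_prop import ShapeProp


class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        return torch.relu(self.linear(x))


# Trace the module, then run ShapeProp with an example input so every node
# gets shape/dtype metadata recorded in node.meta["tensor_meta"].
gm = symbolic_trace(TinyModel())
ShapeProp(gm).propagate(torch.randn(2, 4))
for node in gm.graph.nodes:
    meta = node.meta.get("tensor_meta")
    if meta is not None:
        print(node.op, node.target, meta.dtype, tuple(meta.shape))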
@@ -98,22 +118,22 @@ def call_module(self, m, forward, args, kwargs):
         return proxy

     def create_proxy(self, kind, target, args, kwargs, name=None, type_expr=None, proxy_factory_fn=None):
-        if self.root_is_in_half_precision:
-            float32_dtype_in_args = any(a is torch.float32 for a in args)
-            float32_dtype_in_kwargs = kwargs.get("dtype", None) is torch.float32
-            node_types_to_inspect = [
-                ("call_method", "to"),
-                ("call_function", torch.full),
-            ]
-            torch_methods_to_patched_version = {
-                orig: wrapped for (orig, wrapped) in self.patched_torch_methods.values()
-            }
-            for (k, t) in node_types_to_inspect:
-                if kind == k and target == torch_methods_to_patched_version.get(t, t):
-                    if float32_dtype_in_args:
-                        args = tuple(a if a is not torch.float32 else torch.float16 for a in args)
-                    if float32_dtype_in_kwargs:
-                        kwargs["dtype"] = torch.float16
+        # if self.root_is_in_half_precision:
+        #     float32_dtype_in_args = any(a is torch.float32 for a in args)
+        #     float32_dtype_in_kwargs = kwargs.get("dtype", None) is torch.float32
+        #     node_types_to_inspect = [
+        #         ("call_method", "to"),
+        #         ("call_function", torch.full),
+        #     ]
+        #     torch_methods_to_patched_version = {
+        #         orig: wrapped for (orig, wrapped) in self.patched_torch_methods.values()
+        #     }
+        #     for (k, t) in node_types_to_inspect:
+        #         if kind == k and target == torch_methods_to_patched_version.get(t, t):
+        #             if float32_dtype_in_args:
+        #                 args = tuple(a if a is not torch.float32 else torch.float16 for a in args)
+        #             if float32_dtype_in_kwargs:
+        #                 kwargs["dtype"] = torch.float16
         return super().create_proxy(
             kind, target, args, kwargs, name=name, type_expr=type_expr, proxy_factory_fn=proxy_factory_fn
         )
@@ -149,7 +169,6 @@ def symbolic_trace_with_pipelined_tracer(
     model: PipelineMixin,
     input_names: Optional[List[str]] = None,
 ) -> torch.fx.GraphModule:
-
     """
     Performs symbolic tracing on the model.
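The DeBERTa changes below add ChangeTorchGather, a ReversibleTransformation that retargets torch.gather call_function nodes to a custom IPU op and back. As a standalone illustration of that retargeting pattern (plain torch.fx only, no optimum or poptorch imports; AddModel is a made-up example), swapping torch.add for torch.mul looks like this:

import operator

import torch
from torch import fx


class AddModel(torch.nn.Module):
    def forward(self, x, y):
        return torch.add(x, y)


gm = fx.symbolic_trace(AddModel())
# Retarget every addition node, then regenerate the GraphModule's Python code.
for node in gm.graph.nodes:
    if node.op == "call_function" and node.target in (torch.add, operator.add):
        node.target = torch.mul
gm.recompile()
print(gm(torch.full((2,), 3.0), torch.full((2,), 2.0)))  # tensor([6., 6.]) instead of tensor([5., 5.])

Reversing the transformation is the symmetric loop that sets the target back, which is what ChangeTorchGather.reverse does for torch.gather.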
diff --git a/optimum/graphcore/models/deberta/modeling_deberta.py b/optimum/graphcore/models/deberta/modeling_deberta.py
index 91afc6c6e..56b169e75 100644
--- a/optimum/graphcore/models/deberta/modeling_deberta.py
+++ b/optimum/graphcore/models/deberta/modeling_deberta.py
@@ -37,7 +37,7 @@
 )
 from transformers.utils.fx import _gen_constructor_wrapper

-from ....fx.optimization import MergeLinears, compose
+from ....fx.optimization import MergeLinears, ReversibleTransformation, compose
 from ....utils import logging
 from ...fx import (
     DEFAULT_TRANSFORMATION_MANAGER,
@@ -46,6 +46,8 @@
     OutlineAttribute,
     RecomputationCheckpoint,
     VocabEmbeddingToSerializedEmbedding,
+    LinearToSerializedLinear,
+    TieWeights,
     symbolic_trace_pipelined_model,
 )
 from ...modeling_utils import OnehotGather, PipelineMixin, get_layer_ipu, register
@@ -107,7 +109,32 @@ def _get_rel_embedding(self):
         return self.rel_embeddings.weight + 0.0 if self.relative_attention else None


-gather_last_dim = FastGatherLastDim()
+def faster_gather_last_dim(input, dim, index, *args, **kwargs):
+    target = torch.zeros_like(index).to(input.dtype)
+    target.requires_grad_()
+    o = poptorch.custom_op(
+        [input, index],
+        "FastGatherLastDim",
+        "poptorch.custom_ops",
+        1,
+        example_outputs=[target],
+        attributes={"axis": -1},
+    )
+    return o[0]
+
+
+class ChangeTorchGather(ReversibleTransformation):
+    def transform(self, graph_module):
+        for node in graph_module.graph.nodes:
+            if node.op == "call_function" and node.target is torch.gather:
+                node.target = faster_gather_last_dim
+        return graph_module
+
+    def reverse(self, graph_module):
+        for node in graph_module.graph.nodes:
+            if node.op == "call_function" and node.target is faster_gather_last_dim:
+                node.target = torch.gather
+        return graph_module


 class IPUDisentangledSelfAttention(DisentangledSelfAttention):
@@ -124,8 +151,6 @@ class IPUDisentangledSelfAttention(DisentangledSelfAttention):
     def __init__(self, config):
         super().__init__(config)
         self.xsoftmax = XSoftmax(-1)
-        # self.gather_last_dim = FastGatherLastDim()
-        self.gather_last_dim = gather_last_dim

     def forward(
         self,
@@ -248,7 +273,8 @@ def disentangled_att_bias(self, query_layer, key_layer, relative_pos, rel_embedd
             index = c2p_pos.expand(
                 [query_layer.size(0), query_layer.size(1), query_layer.size(2), relative_pos.size(-1)]
             )
-            c2p_att = self.gather_last_dim(c2p_att, index)
+            # c2p_att = gather_last_dim(c2p_att, index)
+            c2p_att = torch.gather(c2p_att, -1, index)
             score += c2p_att

         # position->content
@@ -263,12 +289,12 @@ def disentangled_att_bias(self, query_layer, key_layer, relative_pos, rel_embedd
             p2c_pos = torch.clamp(-r_pos + att_span, 0, att_span * 2 - 1)
             index = p2c_pos.expand([query_layer.size(0), query_layer.size(1), key_layer.size(-2), key_layer.size(-2)])
             p2c_att = torch.matmul(key_layer, pos_query_layer.transpose(-1, -2))
-            p2c_att = self.gather_last_dim(p2c_att, index).transpose(-1, -2)
+            p2c_att = torch.gather(p2c_att, -1, index).transpose(-1, -2)
             if query_layer.size(-2) != key_layer.size(-2):
                 pos_index = relative_pos[:, :, :, 0].unsqueeze(-1)
                 index = pos_index.expand(pos_index, p2c_att, key_layer)
-                p2c_att = self.gather_last_dim(p2c_att, index)
+                p2c_att = torch.gather(p2c_att, -1, index)
             score += p2c_att

         return score
@@ -283,7 +309,6 @@ def change_modules_for_ipu(self, restore: bool):
                     del mod.xsoftmax
                 else:
                     mod.add_module("xsoftmax", XSoftmax(-1))
-                    mod.add_module("gather_last_dim", FastGatherLastDim())
             if restore:
                 if isinstance(mod, nn.Dropout):
                     mod.__class__ = StableDropout
@@ -302,10 +327,10 @@ def change_modules_for_ipu(self, restore: bool):
     def get_transformations(self):
         log_insertions = self.ipu_config.log_insertions
         layer_ipu = get_layer_ipu(self.ipu_config.layers_per_ipu)
+        # TODO: handle DebertaForMaskedLM
         transformations = [
             AddPoptorchBlock("Embedding", 0, "deberta.embeddings", log_insertions=log_insertions),
             OutlineAttribute("deberta.embeddings.LayerNorm", "Embedding"),
-            AddPoptorchBlock("Before Encoder", 0, "deberta.encoder", log_insertions=log_insertions),
             AddPoptorchBlocksInSeries(
                 "Encoder", layer_ipu, r"deberta.encoder.layer.[0-9]+", log_insertions=log_insertions
             ),
@@ -322,7 +347,13 @@ def get_transformations(self):
             )
         )
         if self.ipu_config.embedding_serialization_factor > 1:
-            transformations.append(VocabEmbeddingToSerializedEmbedding())
+            if isinstance(self, DebertaForMaskedLM):
+                transformations += [
+                    LinearToSerializedLinear("cls.predictions.decoder"),
+                    TieWeights("deberta.embeddings.word_embeddings", "cls.predictions.decoder"),
+                ]
+            else:
+                transformations.append(VocabEmbeddingToSerializedEmbedding())
         return transformations

     def parallelize(self):
@@ -339,6 +370,7 @@ def parallelize(self):
         torch.nn.functional.one_hot = orig
         transformations = self.get_transformations()
         transformations += TRANSFORMATION_MANAGER.get_reversible_transformations(self.ipu_config.optimization_level)
+        transformations.append(ChangeTorchGather())
         composition = compose(*transformations)
         non_reversible_composition = TRANSFORMATION_MANAGER.compose_non_reversible_transformations(
             self.ipu_config.optimization_level
diff --git a/optimum/graphcore/models/gpt2/modeling_gpt2.py b/optimum/graphcore/models/gpt2/modeling_gpt2.py
index b4401c5db..ce672243a 100644
--- a/optimum/graphcore/models/gpt2/modeling_gpt2.py
+++ b/optimum/graphcore/models/gpt2/modeling_gpt2.py
@@ -131,7 +131,7 @@ def deparallelize(self):


 @register(GPT2LMHeadModel)
-class PipelinedGPT2LMHeadModel(GPT2LMHeadModel, GPT2PipelineMixin):
+class PipelinedGPT2LMHeadModel(GPT2PipelineMixin, GPT2LMHeadModel):
     def get_transformations(self):
         log_insertions = self.ipu_config.log_insertions
         layer_ipu = get_layer_ipu(self.ipu_config.layers_per_ipu)
@@ -241,7 +241,7 @@ def forward(


 @register(GPT2ForSequenceClassification)
-class PipelinedGPT2ForSequenceClassification(GPT2ForSequenceClassification, GPT2PipelineMixin):
+class PipelinedGPT2ForSequenceClassification(GPT2PipelineMixin, GPT2ForSequenceClassification):
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
@@ -290,5 +290,5 @@ def forward(


 @register(GPT2ForTokenClassification)
-class PipelinedGPT2ForTokenClassification(GPT2ForTokenClassification, GPT2PipelineMixin):
+class PipelinedGPT2ForTokenClassification(GPT2PipelineMixin, GPT2ForTokenClassification):
     pass
diff --git a/optimum/graphcore/trainer.py b/optimum/graphcore/trainer.py
index 8d050292a..b499ed9f9 100644
--- a/optimum/graphcore/trainer.py
+++ b/optimum/graphcore/trainer.py
@@ -282,6 +282,8 @@ def __init__(
         if args.ipu_config_overrides:
             logger.info(f"Overriding IPU config: {args.ipu_config_overrides}")
             self.ipu_config.update_from_string(args.ipu_config_overrides)
+        if self.args.gradient_accumulation_steps is None:
+            self.args.gradient_accumulation_steps = self.ipu_config.gradient_accumulation_steps
         self.ipu_config.seed = self.args.seed
         self.opts = self.ipu_config.to_options(compile_only=args.compile_only)
         self.eval_opts = self.ipu_config.to_options(for_inference=True, compile_only=args.compile_only)
@@ -1116,7 +1118,7 @@ def _inner_training_loop(
         logger.info(f"  Num Epochs = {num_train_epochs}")
         logger.info(f"  Instantaneous batch size per device = {batch_size}")
         logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}")
-        logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+        logger.info(f"  Gradient Accumulation steps = {self.ipu_config.gradient_accumulation_steps}")
         logger.info(f"  Total optimization steps = {max_steps}")

         self.state.epoch = 0
@@ -1208,7 +1210,7 @@ def _inner_training_loop(
             steps_in_epoch = (
                 len(epoch_iterator)
                 if has_length(train_dataloader)
-                else args.max_steps * args.gradient_accumulation_steps
+                else args.max_steps * self.ipu_config.gradient_accumulation_steps
             )

             self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control)
diff --git a/optimum/graphcore/training_args.py b/optimum/graphcore/training_args.py
index d8ab4456f..80bb79e0e 100644
--- a/optimum/graphcore/training_args.py
+++ b/optimum/graphcore/training_args.py
@@ -750,8 +750,6 @@ def __post_init__(self):
         override_str = []
         if self.gradient_accumulation_steps is not None:
             override_str.append(f"gradient_accumulation_steps={self.gradient_accumulation_steps}")
-        else:
-            self.gradient_accumulation_steps = 1
         if self.auto_loss_scaling:
             override_str.append(f"auto_loss_scaling={self.auto_loss_scaling}")

From a582a465f1d85c38369a8ebbe9f410a598230700 Mon Sep 17 00:00:00 2001
From: Michael Benayoun
Date: Tue, 8 Nov 2022 15:45:20 +0100
Subject: [PATCH 30/33] GPT-2 Fix

---
 optimum/graphcore/fx/transformation_manager.py    |  3 ++-
 optimum/graphcore/fx/transformations.py           | 14 +++++++++++---
 optimum/graphcore/fx/utils.py                     |  1 -
 .../graphcore/models/deberta/modeling_deberta.py  |  4 ++--
 optimum/graphcore/models/gpt2/modeling_gpt2.py    | 10 +++++++---
 5 files changed, 22 insertions(+), 10 deletions(-)

diff --git a/optimum/graphcore/fx/transformation_manager.py b/optimum/graphcore/fx/transformation_manager.py
index 0812989c0..be96b3bc5 100644
--- a/optimum/graphcore/fx/transformation_manager.py
+++ b/optimum/graphcore/fx/transformation_manager.py
@@ -16,6 +16,7 @@

 import copy
 import functools
+import operator
 from typing import Iterator, List, Tuple, Union

 import torch
@@ -123,6 +124,6 @@ def compose_reversible_transformations(self, optimization_level: int) -> Reversi
             (1, MergeLinears()),
             # (1, FuseBiasInLinear()),
             # Those change the computation, but are actually needed for fp16 stability.
-            (0, ClipValuesSymmetric(1e4, exclude_targets=("view",))),
+            (0, ClipValuesSymmetric(1e4, include_targets=(torch.add, torch.mul, operator.add, operator.mul))),
             (0, ClipValues(1e-4, float("inf"), include_targets=(torch.nn.LayerNorm,))),
         )
diff --git a/optimum/graphcore/fx/transformations.py b/optimum/graphcore/fx/transformations.py
index 87868d395..22cd54fcb 100644
--- a/optimum/graphcore/fx/transformations.py
+++ b/optimum/graphcore/fx/transformations.py
@@ -204,7 +204,9 @@ def __init__(
     ):
         if clip_value < 0:
             raise ValueError(f"The provided clip value must be equal or greater than 0, but here {clip_value}.")
-        return super().__init__(-clip_value, clip_value, exclude_targets=exclude_targets)
+        return super().__init__(
+            -clip_value, clip_value, include_targets=include_targets, exclude_targets=exclude_targets
+        )


 class OutlineAttribute(ReversibleTransformation):
@@ -406,7 +408,9 @@ def sort_nodes_function(node):

         embedding_node = max(embedding_nodes, key=sort_nodes_function)
         if embedding_node.op == "call_function":
-            raise NotImplementedError("VocabEmbeddingToSerializedEmbedding does not support torch.nn.functional.embedding yet.")
+            raise NotImplementedError(
+                "VocabEmbeddingToSerializedEmbedding does not support torch.nn.functional.embedding yet."
+            )

         split = embedding_node.target.rsplit(".", maxsplit=1)
         if len(split) == 1:
@@ -520,7 +524,11 @@ def transform(self, graph_module: "GraphModule") -> "GraphModule":


 class ShareEmbeddingComputation(Transformation):
-    def __init__(self, name_regex: Optional[str] = None, allowed_embedding_classes: Union[Tuple[Type], Type] = (torch.nn.Embedding, SerializedEmbedding)):
+    def __init__(
+        self,
+        name_regex: Optional[str] = None,
+        allowed_embedding_classes: Union[Tuple[Type], Type] = (torch.nn.Embedding, SerializedEmbedding),
+    ):
         self.name_regex = re.compile(name_regex) if name_regex else None
         self.allowed_embedding_classes = allowed_embedding_classes
         if not isinstance(self.allowed_embedding_classes, tuple):
diff --git a/optimum/graphcore/fx/utils.py b/optimum/graphcore/fx/utils.py
index 39f8d96ed..c876941cc 100644
--- a/optimum/graphcore/fx/utils.py
+++ b/optimum/graphcore/fx/utils.py
@@ -36,7 +36,6 @@

 # TODO: keep this until transformers >= 4.23.2
 class GCProxy(HFProxy):
-
     @property
     def dtype(self):
         return self.__getattr__("dtype")
diff --git a/optimum/graphcore/models/deberta/modeling_deberta.py b/optimum/graphcore/models/deberta/modeling_deberta.py
index 56b169e75..64a2c099b 100644
--- a/optimum/graphcore/models/deberta/modeling_deberta.py
+++ b/optimum/graphcore/models/deberta/modeling_deberta.py
@@ -43,11 +43,11 @@
     DEFAULT_TRANSFORMATION_MANAGER,
     AddPoptorchBlock,
     AddPoptorchBlocksInSeries,
+    LinearToSerializedLinear,
     OutlineAttribute,
     RecomputationCheckpoint,
-    VocabEmbeddingToSerializedEmbedding,
-    LinearToSerializedLinear,
     TieWeights,
+    VocabEmbeddingToSerializedEmbedding,
     symbolic_trace_pipelined_model,
 )
 from ...modeling_utils import OnehotGather, PipelineMixin, get_layer_ipu, register
diff --git a/optimum/graphcore/models/gpt2/modeling_gpt2.py b/optimum/graphcore/models/gpt2/modeling_gpt2.py
index ce672243a..027cc8b81 100644
--- a/optimum/graphcore/models/gpt2/modeling_gpt2.py
+++ b/optimum/graphcore/models/gpt2/modeling_gpt2.py
@@ -37,6 +37,7 @@
     symbolic_trace_pipelined_model,
 )
 from ...modeling_utils import PipelineMixin, get_layer_ipu, register
+from .optimized_gpt2_attn import OptimizedGPT2Attention


 logger = logging.get_logger(__name__)
@@ -69,7 +70,7 @@ def get_transformations(self):
         layer_ipu = get_layer_ipu(self.ipu_config.layers_per_ipu)
         transformations = [
             AddPoptorchBlock("Token Embedding", 0, "transformer.wte", log_insertions=log_insertions),
-            AddPoptorchBlock("Position Embedding", 1, "transformer.wtp", log_insertions=log_insertions),
+            AddPoptorchBlock("Position Embedding", 0, "transformer.wpe", log_insertions=log_insertions),
             OutlineAttribute("transformer.ln_f", "LayerNorm"),
             AddPoptorchBlocksInSeries("Layer", layer_ipu, r"transformer.h.[0-9]+", log_insertions=log_insertions),
             # Only one of the following AddPoptorchBlock, will actually add a block.
@@ -84,7 +85,7 @@ def get_transformations(self):
             )
         )
         if self.ipu_config.embedding_serialization_factor > 1:
-            transformations.append(VocabEmbeddingToSerializedEmbedding())
+            transformations.append(VocabEmbeddingToSerializedEmbedding("transformer.wte"))
         return transformations


@@ -96,6 +97,9 @@ def parallelize(self):
         - Adds recomputation checkpoints
         """
         PipelineMixin.parallelize(self)
+        if not isinstance(self, torch.fx.GraphModule):
+            for layer in self.transformer.h:
+                layer.attn.__class__ = OptimizedGPT2Attention
         if self.ipu_config.embedding_serialization_factor > 1:
             self.resize_vocab(False)
         traced = symbolic_trace_pipelined_model(self)
@@ -137,7 +141,7 @@ def get_transformations(self):
         layer_ipu = get_layer_ipu(self.ipu_config.layers_per_ipu)
         transformations = [
             AddPoptorchBlock("Token Embedding", 0, "transformer.wte", log_insertions=log_insertions),
-            AddPoptorchBlock("Position Embedding", 1, "transformer.wtp", log_insertions=log_insertions),
+            AddPoptorchBlock("Position Embedding", 0, "transformer.wpe", log_insertions=log_insertions),
             OutlineAttribute("transformer.ln_f", "LayerNorm"),
             AddPoptorchBlocksInSeries("Layer", layer_ipu, r"transformer.h.[0-9]+", log_insertions=log_insertions),
             AddPoptorchBlock("LM Head", 0, "lm_head", log_insertions=log_insertions),

From 43f20011a9537d3600bb4279e708de3252fa7681 Mon Sep 17 00:00:00 2001
From: Michael Benayoun
Date: Thu, 10 Nov 2022 16:07:57 +0100
Subject: [PATCH 31/33] commit for diff

---
 optimum/graphcore/models/deberta/modeling_deberta.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/optimum/graphcore/models/deberta/modeling_deberta.py b/optimum/graphcore/models/deberta/modeling_deberta.py
index 64a2c099b..7157791e4 100644
--- a/optimum/graphcore/models/deberta/modeling_deberta.py
+++ b/optimum/graphcore/models/deberta/modeling_deberta.py
@@ -268,12 +268,10 @@ def disentangled_att_bias(self, query_layer, key_layer, relative_pos, rel_embedd
             pos_key_layer = self.pos_proj(rel_embeddings)
             pos_key_layer = self.transpose_for_scores(pos_key_layer)
             c2p_att = torch.matmul(query_layer, pos_key_layer.transpose(-1, -2))
-            # c2p_pos = torch.clamp(relative_pos + att_span, 0, att_span * 2 - 1)
-            c2p_pos = (relative_pos + att_span).clamp(0, att_span * 2 - 1)
+            c2p_pos = torch.clamp(relative_pos + att_span, 0, att_span * 2 - 1)
             index = c2p_pos.expand(
                 [query_layer.size(0), query_layer.size(1), query_layer.size(2), relative_pos.size(-1)]
             )
-            # c2p_att = gather_last_dim(c2p_att, index)
             c2p_att = torch.gather(c2p_att, -1, index)
             score += c2p_att

From 08dc5d18afb436515544d50c765dddb1d1c41043 Mon Sep 17 00:00:00 2001
From: Michael Benayoun
Date: Mon, 14 Nov 2022 14:43:58 +0100
Subject: [PATCH 32/33] Fix transformation => Deberta compiles

---
 optimum/graphcore/fx/transformations.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/optimum/graphcore/fx/transformations.py b/optimum/graphcore/fx/transformations.py
index 22cd54fcb..da252a20b 100644
--- a/optimum/graphcore/fx/transformations.py
+++ b/optimum/graphcore/fx/transformations.py
@@ -176,7 +176,7 @@ def _clip_node_args(self, args):
         elif isinstance(args, dict):
             return {name: self._clip_node_args(arg) for name, arg in args.items()}
         elif isinstance(args, (float, int)):
-            return min(max(args, self.min_value), self.max_value)
+            return type(args)(min(max(args, self.min_value), self.max_value))
         else:
             return args

From 586fde2f00de6b25afcfd9d2cbef939afd5408fd Mon Sep 17 00:00:00 2001
From: Michael Benayoun
Date: Mon, 14 Nov 2022 17:58:39 +0100
Subject: [PATCH 33/33] Fixed BART

---
 optimum/graphcore/fx/transformations.py        | 15 +++++++++++++--
 optimum/graphcore/models/bart/modeling_bart.py | 14 ++++++++++++--
 2 files changed, 25 insertions(+), 4 deletions(-)

diff --git a/optimum/graphcore/fx/transformations.py b/optimum/graphcore/fx/transformations.py
index da252a20b..2f8bb94dd 100644
--- a/optimum/graphcore/fx/transformations.py
+++ b/optimum/graphcore/fx/transformations.py
@@ -164,11 +164,13 @@ def __init__(
         max_value: float,
         include_targets: Optional[Tuple[Union[str, Callable]]] = None,
         exclude_targets: Optional[Tuple[Union[str, Callable]]] = None,
+        cast_to_type: Optional[Union[Type, Callable[[Type], Type]]] = None,
     ):
         self.min_value = min_value
         self.max_value = max_value
         self.include_targets = include_targets if include_targets is not None else ()
         self.exclude_targets = exclude_targets if exclude_targets is not None else ()
+        self.cast_to_type = cast_to_type

     def _clip_node_args(self, args):
         if isinstance(args, (tuple, list, set)):
@@ -176,7 +178,11 @@ def _clip_node_args(self, args):
         elif isinstance(args, dict):
             return {name: self._clip_node_args(arg) for name, arg in args.items()}
         elif isinstance(args, (float, int)):
-            return type(args)(min(max(args, self.min_value), self.max_value))
+            if self.cast_to_type is None:
+                cast_to_type = type(args)
+            else:
+                cast_to_type = self.cast_to_type
+            return cast_to_type(min(max(args, self.min_value), self.max_value))
         else:
             return args

@@ -201,11 +207,16 @@ def __init__(
         clip_value: float,
         include_targets: Optional[Tuple[Union[str, Callable]]] = None,
         exclude_targets: Optional[Tuple[Union[str, Callable]]] = None,
+        cast_to_type: Optional[Union[Type, Callable[[Type], Type]]] = None,
     ):
         if clip_value < 0:
             raise ValueError(f"The provided clip value must be equal or greater than 0, but here {clip_value}.")
         return super().__init__(
-            -clip_value, clip_value, include_targets=include_targets, exclude_targets=exclude_targets
+            -clip_value,
+            clip_value,
+            include_targets=include_targets,
+            exclude_targets=exclude_targets,
+            cast_to_type=cast_to_type,
         )


diff --git a/optimum/graphcore/models/bart/modeling_bart.py b/optimum/graphcore/models/bart/modeling_bart.py
index a056034f4..e35c43be5 100644
--- a/optimum/graphcore/models/bart/modeling_bart.py
+++ b/optimum/graphcore/models/bart/modeling_bart.py
@@ -14,6 +14,7 @@
 #  limitations under the License.
 """BART model."""
+import operator
 from typing import List, Optional, Tuple, Union

 import torch
@@ -46,8 +47,17 @@

 FLOAT16_LIMIT = 1e4

-TRANSFORMATION_MANAGER = DEFAULT_TRANSFORMATION_MANAGER.without(ClipValuesSymmetric(1e4, exclude_targets=("view",)))
-TRANSFORMATION_MANAGER.register((0, ClipValuesSymmetric(10000, exclude_targets=("view",))))
+TRANSFORMATION_MANAGER = DEFAULT_TRANSFORMATION_MANAGER.without(
+    ClipValuesSymmetric(1e4, include_targets=(torch.add, torch.mul, operator.add, operator.mul))
+)
+TRANSFORMATION_MANAGER.register(
+    (
+        0,
+        ClipValuesSymmetric(
+            10000, include_targets=(torch.add, torch.mul, operator.add, operator.mul), cast_to_type=int
+        ),
+    )
+)


 def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0):